Skip to content

Commit 5388c97

Browse files
committed
Simulate template argument pack expansion
1 parent 8e012c8 commit 5388c97

File tree

11 files changed

+229
-8
lines changed

11 files changed

+229
-8
lines changed

src/parser/cxx/ast_rewriter.cc

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
// cxx
2424
#include <cxx/ast.h>
25+
#include <cxx/ast_cursor.h>
2526
#include <cxx/binder.h>
2627
#include <cxx/control.h>
2728
#include <cxx/decl.h>
@@ -31,6 +32,8 @@
3132
#include <cxx/type_checker.h>
3233
#include <cxx/types.h>
3334

35+
#include <format>
36+
3437
namespace cxx {
3538

3639
ASTRewriter::ASTRewriter(TypeChecker* typeChcker,
@@ -54,6 +57,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
5457
restrictedToDeclarations_ = restrictedToDeclarations;
5558
}
5659

60+
auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
61+
for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
62+
const auto& current = *cursor;
63+
if (!std::holds_alternative<AST*>(current.node)) continue;
64+
65+
auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node));
66+
if (!id) continue;
67+
68+
auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol);
69+
if (!param) continue;
70+
71+
if (param->depth() != 0) continue;
72+
73+
auto arg = templateArguments_[param->index()];
74+
auto argSymbol = std::get<Symbol*>(arg);
75+
76+
auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
77+
if (parameterPack) return parameterPack;
78+
}
79+
80+
return nullptr;
81+
}
82+
5783
struct ASTRewriter::UnitVisitor {
5884
ASTRewriter& rewrite;
5985
[[nodiscard]] auto translationUnit() const -> TranslationUnit* {
@@ -2345,6 +2371,28 @@ auto ASTRewriter::ExpressionVisitor::operator()(NestedExpressionAST* ast)
23452371

23462372
auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23472373
-> ExpressionAST* {
2374+
if (auto param = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
2375+
param && param->depth() == 0 &&
2376+
param->index() < rewrite.templateArguments_.size()) {
2377+
auto symbolPtr =
2378+
std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
2379+
2380+
if (!symbolPtr) {
2381+
cxx_runtime_error("expected initializer for non-type template parameter");
2382+
}
2383+
2384+
auto parameterPack = symbol_cast<ParameterPackSymbol>(*symbolPtr);
2385+
2386+
if (parameterPack && parameterPack == rewrite.parameterPack_ &&
2387+
rewrite.elementIndex_.has_value()) {
2388+
auto idx = rewrite.elementIndex_.value();
2389+
auto element = parameterPack->elements()[idx];
2390+
if (auto var = symbol_cast<VariableSymbol>(element)) {
2391+
return rewrite(var->initializer());
2392+
}
2393+
}
2394+
}
2395+
23482396
auto copy = make_node<IdExpressionAST>(arena());
23492397

23502398
copy->valueCategory = ast->valueCategory;
@@ -2355,15 +2403,17 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23552403

23562404
copy->symbol = ast->symbol;
23572405

2358-
if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
2359-
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
2360-
auto initializerPtr =
2361-
std::get_if<Symbol*>(&rewrite.templateArguments_[x->index()]);
2362-
if (!initializerPtr) {
2406+
if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
2407+
param && param->depth() == 0 &&
2408+
param->index() < rewrite.templateArguments_.size()) {
2409+
auto symbolPtr =
2410+
std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
2411+
2412+
if (!symbolPtr) {
23632413
cxx_runtime_error("expected initializer for non-type template parameter");
23642414
}
23652415

2366-
copy->symbol = *initializerPtr;
2416+
copy->symbol = *symbolPtr;
23672417
copy->type = copy->symbol->type();
23682418
}
23692419

@@ -2473,6 +2523,41 @@ auto ASTRewriter::ExpressionVisitor::operator()(RightFoldExpressionAST* ast)
24732523

24742524
auto ASTRewriter::ExpressionVisitor::operator()(LeftFoldExpressionAST* ast)
24752525
-> ExpressionAST* {
2526+
if (auto parameterPack = rewrite.getParameterPack(ast->expression)) {
2527+
auto savedParameterPack = rewrite.parameterPack_;
2528+
std::swap(rewrite.parameterPack_, parameterPack);
2529+
2530+
std::vector<ExpressionAST*> instantiations;
2531+
ExpressionAST* current = nullptr;
2532+
2533+
int n = 0;
2534+
for (auto element : rewrite.parameterPack_->elements()) {
2535+
std::optional<int> index{n};
2536+
std::swap(rewrite.elementIndex_, index);
2537+
2538+
auto expression = rewrite(ast->expression);
2539+
if (!current) {
2540+
current = expression;
2541+
} else {
2542+
auto binop = make_node<BinaryExpressionAST>(arena());
2543+
binop->valueCategory = current->valueCategory;
2544+
binop->type = current->type;
2545+
binop->leftExpression = current;
2546+
binop->op = ast->op;
2547+
binop->opLoc = ast->opLoc;
2548+
binop->rightExpression = expression;
2549+
current = binop;
2550+
}
2551+
2552+
std::swap(rewrite.elementIndex_, index);
2553+
++n;
2554+
}
2555+
2556+
std::swap(rewrite.parameterPack_, parameterPack);
2557+
2558+
return current;
2559+
}
2560+
24762561
auto copy = make_node<LeftFoldExpressionAST>(arena());
24772562

24782563
copy->valueCategory = ast->valueCategory;

src/parser/cxx/ast_rewriter.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,13 @@ class ASTRewriter {
151151
private:
152152
[[nodiscard]] auto rewriter() -> ASTRewriter* { return this; }
153153

154+
[[nodiscard]] auto getParameterPack(ExpressionAST* ast)
155+
-> ParameterPackSymbol*;
156+
154157
TypeChecker* typeChecker_ = nullptr;
155158
const std::vector<TemplateArgument>& templateArguments_;
159+
ParameterPackSymbol* parameterPack_ = nullptr;
160+
std::optional<int> elementIndex_;
156161
TranslationUnit* unit_ = nullptr;
157162
Binder binder_;
158163
bool restrictedToDeclarations_ = false;

src/parser/cxx/control.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ struct Control::Private {
143143
std::forward_list<VariableSymbol> variableSymbols;
144144
std::forward_list<FieldSymbol> fieldSymbols;
145145
std::forward_list<ParameterSymbol> parameterSymbols;
146+
std::forward_list<ParameterPackSymbol> parameterPackSymbols;
146147
std::forward_list<TypeParameterSymbol> typeParameterSymbols;
147148
std::forward_list<NonTypeParameterSymbol> nonTypeParameterSymbols;
148149
std::forward_list<TemplateTypeParameterSymbol> templateTypeParameterSymbols;
@@ -592,6 +593,13 @@ auto Control::newParameterSymbol(Scope* enclosingScope, SourceLocation loc)
592593
return symbol;
593594
}
594595

596+
auto Control::newParameterPackSymbol(Scope* enclosingScope, SourceLocation loc)
597+
-> ParameterPackSymbol* {
598+
auto symbol = &d->parameterPackSymbols.emplace_front(enclosingScope);
599+
symbol->setLocation(loc);
600+
return symbol;
601+
}
602+
595603
auto Control::newTypeParameterSymbol(Scope* enclosingScope, SourceLocation loc)
596604
-> TypeParameterSymbol* {
597605
auto symbol = &d->typeParameterSymbols.emplace_front(enclosingScope);

src/parser/cxx/control.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ class Control {
202202
[[nodiscard]] auto newParameterSymbol(Scope* enclosingScope,
203203
SourceLocation sourceLocation)
204204
-> ParameterSymbol*;
205+
[[nodiscard]] auto newParameterPackSymbol(Scope* enclosingScope,
206+
SourceLocation sourceLocation)
207+
-> ParameterPackSymbol*;
205208
[[nodiscard]] auto newTypeParameterSymbol(Scope* enclosingScope,
206209
SourceLocation sourceLocation)
207210
-> TypeParameterSymbol*;

src/parser/cxx/external_name_encoder.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ struct ExternalNameEncoder::SymbolVisitor {
156156

157157
void operator()(ParameterSymbol* symbol) {}
158158

159+
void operator()(ParameterPackSymbol* symbol) {}
160+
159161
void operator()(EnumeratorSymbol* symbol) {}
160162

161163
void operator()(FunctionParametersSymbol* symbol) {}

src/parser/cxx/symbol_instantiation.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct SymbolInstantiation::VisitSymbol {
6262
[[nodiscard]] auto operator()(VariableSymbol* symbol) -> Symbol*;
6363
[[nodiscard]] auto operator()(FieldSymbol* symbol) -> Symbol*;
6464
[[nodiscard]] auto operator()(ParameterSymbol* symbol) -> Symbol*;
65+
[[nodiscard]] auto operator()(ParameterPackSymbol* symbol) -> Symbol*;
6566
[[nodiscard]] auto operator()(EnumeratorSymbol* symbol) -> Symbol*;
6667
[[nodiscard]] auto operator()(FunctionParametersSymbol* symbol) -> Symbol*;
6768
[[nodiscard]] auto operator()(TemplateParametersSymbol* symbol) -> Symbol*;
@@ -312,6 +313,12 @@ auto SymbolInstantiation::VisitSymbol::operator()(ParameterSymbol* symbol)
312313
return newSymbol;
313314
}
314315

316+
auto SymbolInstantiation::VisitSymbol::operator()(ParameterPackSymbol* symbol)
317+
-> Symbol* {
318+
auto newSymbol = self.replacement(symbol);
319+
return newSymbol;
320+
}
321+
315322
auto SymbolInstantiation::VisitSymbol::operator()(EnumeratorSymbol* symbol)
316323
-> Symbol* {
317324
auto newSymbol = self.replacement(symbol);

src/parser/cxx/symbol_printer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,12 @@ struct DumpSymbols {
269269
to_string(symbol->type(), symbol->name()));
270270
}
271271

272+
void operator()(ParameterPackSymbol* symbol) {
273+
indent();
274+
out << std::format("parameter pack {}\n",
275+
to_string(symbol->type(), symbol->name()));
276+
}
277+
272278
void operator()(TypeParameterSymbol* symbol) {
273279
std::string_view pack = symbol->isParameterPack() ? "..." : "";
274280
indent();

src/parser/cxx/symbols.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,19 @@ ParameterSymbol::ParameterSymbol(Scope* enclosingScope)
615615

616616
ParameterSymbol::~ParameterSymbol() {}
617617

618+
ParameterPackSymbol::ParameterPackSymbol(Scope* enclosingScope)
619+
: Symbol(Kind, enclosingScope) {}
620+
621+
ParameterPackSymbol::~ParameterPackSymbol() {}
622+
623+
auto ParameterPackSymbol::elements() const -> const std::vector<Symbol*>& {
624+
return elements_;
625+
}
626+
627+
void ParameterPackSymbol::addElement(Symbol* element) {
628+
elements_.push_back(element);
629+
}
630+
618631
TypeParameterSymbol::TypeParameterSymbol(Scope* enclosingScope)
619632
: Symbol(Kind, enclosingScope) {}
620633

src/parser/cxx/symbols.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,20 @@ class ParameterSymbol final : public Symbol {
624624
~ParameterSymbol() override;
625625
};
626626

627+
class ParameterPackSymbol final : public Symbol {
628+
public:
629+
constexpr static auto Kind = SymbolKind::kParameterPack;
630+
631+
explicit ParameterPackSymbol(Scope* enclosingScope);
632+
~ParameterPackSymbol() override;
633+
634+
[[nodiscard]] auto elements() const -> const std::vector<Symbol*>&;
635+
void addElement(Symbol* element);
636+
637+
private:
638+
std::vector<Symbol*> elements_;
639+
};
640+
627641
class TypeParameterSymbol final : public Symbol {
628642
public:
629643
constexpr static auto Kind = SymbolKind::kTypeParameter;

src/parser/cxx/symbols_fwd.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace cxx {
3737
V(Variable) \
3838
V(Field) \
3939
V(Parameter) \
40+
V(ParameterPack) \
4041
V(Enumerator) \
4142
V(FunctionParameters) \
4243
V(TemplateParameters) \

0 commit comments

Comments
 (0)