2222
2323// cxx
2424#include < cxx/ast.h>
25+ #include < cxx/ast_cursor.h>
2526#include < cxx/binder.h>
2627#include < cxx/control.h>
2728#include < cxx/decl.h>
3132#include < cxx/type_checker.h>
3233#include < cxx/types.h>
3334
35+ #include < format>
36+
3437namespace cxx {
3538
3639ASTRewriter::ASTRewriter (TypeChecker* typeChcker,
@@ -54,6 +57,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
5457 restrictedToDeclarations_ = restrictedToDeclarations;
5558}
5659
60+ auto ASTRewriter::getParameterPack (ExpressionAST* ast) -> ParameterPackSymbol* {
61+ for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
62+ const auto & current = *cursor;
63+ if (!std::holds_alternative<AST*>(current.node )) continue ;
64+
65+ auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node ));
66+ if (!id) continue ;
67+
68+ auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol );
69+ if (!param) continue ;
70+
71+ if (param->depth () != 0 ) continue ;
72+
73+ auto arg = templateArguments_[param->index ()];
74+ auto argSymbol = std::get<Symbol*>(arg);
75+
76+ auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
77+ if (parameterPack) return parameterPack;
78+ }
79+
80+ return nullptr ;
81+ }
82+
5783struct ASTRewriter ::UnitVisitor {
5884 ASTRewriter& rewrite;
5985 [[nodiscard]] auto translationUnit () const -> TranslationUnit* {
@@ -2345,6 +2371,28 @@ auto ASTRewriter::ExpressionVisitor::operator()(NestedExpressionAST* ast)
23452371
23462372auto ASTRewriter::ExpressionVisitor::operator ()(IdExpressionAST* ast)
23472373 -> ExpressionAST* {
2374+ if (auto param = symbol_cast<NonTypeParameterSymbol>(ast->symbol );
2375+ param && param->depth () == 0 &&
2376+ param->index () < rewrite.templateArguments_ .size ()) {
2377+ auto symbolPtr =
2378+ std::get_if<Symbol*>(&rewrite.templateArguments_ [param->index ()]);
2379+
2380+ if (!symbolPtr) {
2381+ cxx_runtime_error (" expected initializer for non-type template parameter" );
2382+ }
2383+
2384+ auto parameterPack = symbol_cast<ParameterPackSymbol>(*symbolPtr);
2385+
2386+ if (parameterPack && parameterPack == rewrite.parameterPack_ &&
2387+ rewrite.elementIndex_ .has_value ()) {
2388+ auto idx = rewrite.elementIndex_ .value ();
2389+ auto element = parameterPack->elements ()[idx];
2390+ if (auto var = symbol_cast<VariableSymbol>(element)) {
2391+ return rewrite (var->initializer ());
2392+ }
2393+ }
2394+ }
2395+
23482396 auto copy = make_node<IdExpressionAST>(arena ());
23492397
23502398 copy->valueCategory = ast->valueCategory ;
@@ -2355,15 +2403,17 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23552403
23562404 copy->symbol = ast->symbol ;
23572405
2358- if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol );
2359- x && x->depth () == 0 && x->index () < rewrite.templateArguments_ .size ()) {
2360- auto initializerPtr =
2361- std::get_if<Symbol*>(&rewrite.templateArguments_ [x->index ()]);
2362- if (!initializerPtr) {
2406+ if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol );
2407+ param && param->depth () == 0 &&
2408+ param->index () < rewrite.templateArguments_ .size ()) {
2409+ auto symbolPtr =
2410+ std::get_if<Symbol*>(&rewrite.templateArguments_ [param->index ()]);
2411+
2412+ if (!symbolPtr) {
23632413 cxx_runtime_error (" expected initializer for non-type template parameter" );
23642414 }
23652415
2366- copy->symbol = *initializerPtr ;
2416+ copy->symbol = *symbolPtr ;
23672417 copy->type = copy->symbol ->type ();
23682418 }
23692419
@@ -2473,6 +2523,41 @@ auto ASTRewriter::ExpressionVisitor::operator()(RightFoldExpressionAST* ast)
24732523
24742524auto ASTRewriter::ExpressionVisitor::operator ()(LeftFoldExpressionAST* ast)
24752525 -> ExpressionAST* {
2526+ if (auto parameterPack = rewrite.getParameterPack (ast->expression )) {
2527+ auto savedParameterPack = rewrite.parameterPack_ ;
2528+ std::swap (rewrite.parameterPack_ , parameterPack);
2529+
2530+ std::vector<ExpressionAST*> instantiations;
2531+ ExpressionAST* current = nullptr ;
2532+
2533+ int n = 0 ;
2534+ for (auto element : rewrite.parameterPack_ ->elements ()) {
2535+ std::optional<int > index{n};
2536+ std::swap (rewrite.elementIndex_ , index);
2537+
2538+ auto expression = rewrite (ast->expression );
2539+ if (!current) {
2540+ current = expression;
2541+ } else {
2542+ auto binop = make_node<BinaryExpressionAST>(arena ());
2543+ binop->valueCategory = current->valueCategory ;
2544+ binop->type = current->type ;
2545+ binop->leftExpression = current;
2546+ binop->op = ast->op ;
2547+ binop->opLoc = ast->opLoc ;
2548+ binop->rightExpression = expression;
2549+ current = binop;
2550+ }
2551+
2552+ std::swap (rewrite.elementIndex_ , index);
2553+ ++n;
2554+ }
2555+
2556+ std::swap (rewrite.parameterPack_ , parameterPack);
2557+
2558+ return current;
2559+ }
2560+
24762561 auto copy = make_node<LeftFoldExpressionAST>(arena ());
24772562
24782563 copy->valueCategory = ast->valueCategory ;
0 commit comments