2222
2323// cxx
2424#include < cxx/ast.h>
25+ #include < cxx/ast_cursor.h>
2526#include < cxx/binder.h>
2627#include < cxx/control.h>
2728#include < cxx/decl.h>
3132#include < cxx/type_checker.h>
3233#include < cxx/types.h>
3334
35+ #include < format>
36+
3437namespace cxx {
3538
3639ASTRewriter::ASTRewriter (TypeChecker* typeChcker,
@@ -54,6 +57,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
5457 restrictedToDeclarations_ = restrictedToDeclarations;
5558}
5659
60+ auto ASTRewriter::getParameterPack (ExpressionAST* ast) -> ParameterPackSymbol* {
61+ for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
62+ const auto & current = *cursor;
63+ if (!std::holds_alternative<AST*>(current.node )) continue ;
64+
65+ auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node ));
66+ if (!id) continue ;
67+
68+ auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol );
69+ if (!param) continue ;
70+
71+ if (param->depth () != 0 ) continue ;
72+
73+ auto arg = templateArguments_[param->index ()];
74+ auto argSymbol = std::get<Symbol*>(arg);
75+
76+ auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
77+ if (parameterPack) return parameterPack;
78+ }
79+
80+ return nullptr ;
81+ }
82+
5783struct ASTRewriter ::UnitVisitor {
5884 ASTRewriter& rewrite;
5985 [[nodiscard]] auto translationUnit () const -> TranslationUnit* {
@@ -2355,16 +2381,27 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23552381
23562382 copy->symbol = ast->symbol ;
23572383
2358- if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol );
2359- x && x->depth () == 0 && x->index () < rewrite.templateArguments_ .size ()) {
2384+ if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol );
2385+ param && param->depth () == 0 &&
2386+ param->index () < rewrite.templateArguments_ .size ()) {
23602387 auto initializerPtr =
2361- std::get_if<Symbol*>(&rewrite.templateArguments_ [x->index ()]);
2388+ std::get_if<Symbol*>(&rewrite.templateArguments_ [param->index ()]);
2389+
23622390 if (!initializerPtr) {
23632391 cxx_runtime_error (" expected initializer for non-type template parameter" );
23642392 }
23652393
23662394 copy->symbol = *initializerPtr;
23672395 copy->type = copy->symbol ->type ();
2396+
2397+ auto parameterPack = symbol_cast<ParameterPackSymbol>(copy->symbol );
2398+ if (parameterPack && rewrite.elementIndex_ .has_value ()) {
2399+ // ### TODO: check that idx is related to this non-type parameter.
2400+ auto idx = rewrite.elementIndex_ .value ();
2401+ auto element = parameterPack->elements ()[idx];
2402+ copy->symbol = element;
2403+ copy->type = element->type ();
2404+ }
23682405 }
23692406
23702407 copy->isTemplateIntroduced = ast->isTemplateIntroduced ;
@@ -2473,6 +2510,33 @@ auto ASTRewriter::ExpressionVisitor::operator()(RightFoldExpressionAST* ast)
24732510
24742511auto ASTRewriter::ExpressionVisitor::operator ()(LeftFoldExpressionAST* ast)
24752512 -> ExpressionAST* {
2513+ if (auto parameterPack = rewrite.getParameterPack (ast->expression )) {
2514+ std::vector<ExpressionAST*> instantiations;
2515+ ExpressionAST* current = nullptr ;
2516+ int n = 0 ;
2517+ for (auto element : parameterPack->elements ()) {
2518+ std::optional<int > index{n};
2519+ std::swap (rewrite.elementIndex_ , index);
2520+ auto expression = rewrite (ast->expression );
2521+ if (!current) {
2522+ current = expression;
2523+ } else {
2524+ auto binop = make_node<BinaryExpressionAST>(arena ());
2525+ binop->valueCategory = current->valueCategory ;
2526+ binop->type = current->type ;
2527+ binop->leftExpression = current;
2528+ binop->op = ast->op ;
2529+ binop->opLoc = ast->opLoc ;
2530+ binop->rightExpression = expression;
2531+ current = binop;
2532+ }
2533+ std::swap (rewrite.elementIndex_ , index);
2534+ ++n;
2535+ }
2536+
2537+ return current;
2538+ }
2539+
24762540 auto copy = make_node<LeftFoldExpressionAST>(arena ());
24772541
24782542 copy->valueCategory = ast->valueCategory ;
0 commit comments