@@ -62,17 +62,17 @@ export function new_ast_rewriter_cc({
6262 emit ( ` struct ASTRewriter::${ className } Visitor {` ) ;
6363 emit ( ` ASTRewriter& rewrite;` ) ;
6464 emit (
65- `[[nodiscard]] auto translationUnit() const -> TranslationUnit* { return rewrite.unit_; }` ,
65+ `[[nodiscard]] auto translationUnit() const -> TranslationUnit* { return rewrite.unit_; }`
6666 ) ;
6767 emit ( ) ;
6868 emit (
69- `[[nodiscard]] auto control() const -> Control* { return rewrite.control(); }` ,
69+ `[[nodiscard]] auto control() const -> Control* { return rewrite.control(); }`
7070 ) ;
7171 emit (
72- `[[nodiscard]] auto arena() const -> Arena* { return rewrite.arena(); }` ,
72+ `[[nodiscard]] auto arena() const -> Arena* { return rewrite.arena(); }`
7373 ) ;
7474 emit (
75- `[[nodiscard]] auto rewriter() const -> ASTRewriter* { return &rewrite; }` ,
75+ `[[nodiscard]] auto rewriter() const -> ASTRewriter* { return &rewrite; }`
7676 ) ;
7777 nodes . forEach ( ( { name } ) => {
7878 emit ( ) ;
@@ -81,9 +81,17 @@ export function new_ast_rewriter_cc({
8181 emit ( ` };` ) ;
8282 } ) ;
8383
84- const emitRewriterBody = ( members : Member [ ] , visitor : string = "rewrite" ) => {
84+ const emitRewriterBody = ( {
85+ name,
86+ members,
87+ visitor = "rewrite" ,
88+ } : {
89+ name : string ;
90+ members : Member [ ] ;
91+ visitor ?: string ;
92+ } ) => {
8593 const blockSymbol = members . find (
86- ( m ) => m . kind === "attribute" && m . type === "BlockSymbol" ,
94+ ( m ) => m . kind === "attribute" && m . type === "BlockSymbol"
8795 ) ;
8896
8997 if ( blockSymbol ) {
@@ -114,11 +122,11 @@ export function new_ast_rewriter_cc({
114122 if ( specsAttr ) {
115123 emit ( ) ;
116124 emit (
117- `auto ${ m . name } Type = getDeclaratorType(translationUnit(), copy->${ m . name } , ${ specsAttr } Ctx.getType());` ,
125+ `auto ${ m . name } Type = getDeclaratorType(translationUnit(), copy->${ m . name } , ${ specsAttr } Ctx.getType());`
118126 ) ;
119127
120128 typeAttr = members . find (
121- ( m ) => m . kind === "attribute" && m . name === "type" ,
129+ ( m ) => m . kind === "attribute" && m . name === "type"
122130 ) ;
123131
124132 if ( typeAttr ) {
@@ -131,7 +139,7 @@ export function new_ast_rewriter_cc({
131139 } // switch
132140 } else {
133141 emit (
134- `copy->${ m . name } = ast_cast<${ m . type } >(${ visitor } (ast->${ m . name } ));` ,
142+ `copy->${ m . name } = ast_cast<${ m . type } >(${ visitor } (ast->${ m . name } ));`
135143 ) ;
136144 }
137145 break ;
@@ -150,7 +158,7 @@ export function new_ast_rewriter_cc({
150158 } // switch
151159
152160 emit (
153- `for (auto ${ m . name } = ©->${ m . name } ; auto node : ListView{ast->${ m . name } }) {` ,
161+ `for (auto ${ m . name } = ©->${ m . name } ; auto node : ListView{ast->${ m . name } }) {`
154162 ) ;
155163
156164 switch ( m . type ) {
@@ -167,7 +175,7 @@ export function new_ast_rewriter_cc({
167175 emit ( `*${ m . name } = make_list_node(arena(), value);` ) ;
168176 } else {
169177 emit (
170- `*${ m . name } = make_list_node(arena(), ast_cast<${ m . type } >(value));` ,
178+ `*${ m . name } = make_list_node(arena(), ast_cast<${ m . type } >(value));`
171179 ) ;
172180 }
173181 emit ( `${ m . name } = &(*${ m . name } )->next;` ) ;
@@ -188,6 +196,39 @@ export function new_ast_rewriter_cc({
188196 }
189197 case "attribute" : {
190198 emit ( ` copy->${ m . name } = ast->${ m . name } ;` ) ;
199+
200+ if ( m . name == "symbol" && name == "NamedTypeSpecifierAST" ) {
201+ emit ( `
202+ if (auto typeParameter = symbol_cast<TypeParameterSymbol>(copy->symbol)) {
203+ const auto& args = rewrite.templateArguments_;
204+ if (typeParameter && typeParameter->depth() == 0 &&
205+ typeParameter->index() < args.size()) {
206+ auto index = typeParameter->index();
207+
208+ if (auto sym = std::get_if<Symbol*>(&args[index])) {
209+ copy->symbol = *sym;
210+ }
211+ }
212+ }
213+ ` ) ;
214+ }
215+
216+ if ( m . name == "symbol" && name == "IdExpressionAST" ) {
217+ emit ( `
218+ if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
219+ param && param->depth() == 0 &&
220+ param->index() < rewrite.templateArguments_.size()) {
221+ auto symbolPtr =
222+ std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
223+
224+ if (!symbolPtr) {
225+ cxx_runtime_error("expected initializer for non-type template parameter");
226+ }
227+
228+ copy->symbol = *symbolPtr;
229+ copy->type = copy->symbol->type();
230+ }` ) ;
231+ }
191232 break ;
192233 }
193234 case "token" : {
@@ -234,7 +275,7 @@ export function new_ast_rewriter_cc({
234275 switch ( name ) {
235276 case "InitDeclaratorAST" :
236277 emit (
237- `auto ASTRewriter::operator()(${ name } * ast, const DeclSpecs& declSpecs) -> ${ name } * {` ,
278+ `auto ASTRewriter::operator()(${ name } * ast, const DeclSpecs& declSpecs) -> ${ name } * {`
238279 ) ;
239280 break ;
240281 default :
@@ -246,7 +287,7 @@ export function new_ast_rewriter_cc({
246287 emit ( ) ;
247288 emit ( ` auto copy = make_node<${ name } >(arena());` ) ;
248289 emit ( ) ;
249- emitRewriterBody ( members , "operator()" ) ;
290+ emitRewriterBody ( { name , members, visitor : "operator()" } ) ;
250291 emit ( ) ;
251292 emit ( ` return copy;` ) ;
252293 emit ( `}` ) ;
@@ -259,38 +300,79 @@ export function new_ast_rewriter_cc({
259300 nodes . forEach ( ( { name, members } ) => {
260301 emit ( ) ;
261302 emit (
262- `auto ASTRewriter::${ className } Visitor::operator()(${ name } * ast) -> ${ base } * {` ,
303+ `auto ASTRewriter::${ className } Visitor::operator()(${ name } * ast) -> ${ base } * {`
263304 ) ;
305+
264306 if ( name === "IdExpressionAST" ) {
265307 emit ( `
266- if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
267- x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
268- auto initializerPtr =
269- std::get_if<ExpressionAST*>(&rewrite.templateArguments_[x->index()]);
270- if (!initializerPtr) {
308+ if (auto param = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
309+ param && param->depth() == 0 &&
310+ param->index() < rewrite.templateArguments_.size()) {
311+ auto symbolPtr =
312+ std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
313+
314+ if (!symbolPtr) {
271315 cxx_runtime_error("expected initializer for non-type template parameter");
272316 }
273317
274- auto initializer = rewrite(*initializerPtr );
318+ auto parameterPack = symbol_cast<ParameterPackSymbol>(*symbolPtr );
275319
276- if (auto eq = ast_cast<EqualInitializerAST>(initializer)) {
277- return eq->expression;
320+ if (parameterPack && parameterPack == rewrite.parameterPack_ &&
321+ rewrite.elementIndex_.has_value()) {
322+ auto idx = rewrite.elementIndex_.value();
323+ auto element = parameterPack->elements()[idx];
324+ if (auto var = symbol_cast<VariableSymbol>(element)) {
325+ return rewrite(var->initializer());
326+ }
278327 }
328+ }
329+ ` ) ;
330+ }
331+ if ( name === "LeftFoldExpressionAST" ) {
332+ emit ( `
333+ if (auto parameterPack = rewrite.getParameterPack(ast->expression)) {
334+ auto savedParameterPack = rewrite.parameterPack_;
335+ std::swap(rewrite.parameterPack_, parameterPack);
336+
337+ std::vector<ExpressionAST*> instantiations;
338+ ExpressionAST* current = nullptr;
339+
340+ int n = 0;
341+ for (auto element : rewrite.parameterPack_->elements()) {
342+ std::optional<int> index{n};
343+ std::swap(rewrite.elementIndex_, index);
344+
345+ auto expression = rewrite(ast->expression);
346+ if (!current) {
347+ current = expression;
348+ } else {
349+ auto binop = make_node<BinaryExpressionAST>(arena());
350+ binop->valueCategory = current->valueCategory;
351+ binop->type = current->type;
352+ binop->leftExpression = current;
353+ binop->op = ast->op;
354+ binop->opLoc = ast->opLoc;
355+ binop->rightExpression = expression;
356+ current = binop;
357+ }
279358
280- if (auto bracedInit = ast_cast<BracedInitListAST>(initializer)) {
281- if (bracedInit->expressionList && !bracedInit->expressionList->next) {
282- return bracedInit->expressionList->value;
359+ std::swap(rewrite.elementIndex_, index);
360+ ++n;
283361 }
362+
363+ std::swap(rewrite.parameterPack_, parameterPack);
364+
365+ return current;
284366 }
285- }
286367` ) ;
287368 }
369+
288370 emit ( ` auto copy = make_node<${ name } >(arena());` ) ;
289371 emit ( ) ;
290372 ast . baseMembers . get ( base ) ?. forEach ( ( m ) => {
291373 emit ( ` copy->${ m . name } = ast->${ m . name } ;` ) ;
292374 } ) ;
293- emitRewriterBody ( members , "rewrite" ) ;
375+ emitRewriterBody ( { name , members, visitor : "rewrite" } ) ;
294376 emit ( ) ;
295377 emit ( ` return copy;` ) ;
296378 emit ( `}` ) ;
@@ -311,6 +393,9 @@ if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
311393#include <cxx/symbols.h>
312394#include <cxx/types.h>
313395#include <cxx/binder.h>
396+ #include <cxx/ast_cursor.h>
397+
398+ #include <format>
314399
315400namespace cxx {
316401
@@ -339,6 +424,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
339424 restrictedToDeclarations_ = restrictedToDeclarations;
340425}
341426
427+ auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
428+ for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
429+ const auto& current = *cursor;
430+ if (!std::holds_alternative<AST*>(current.node)) continue;
431+
432+ auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node));
433+ if (!id) continue;
434+
435+ auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol);
436+ if (!param) continue;
437+
438+ if (param->depth() != 0) continue;
439+
440+ auto arg = templateArguments_[param->index()];
441+ auto argSymbol = std::get<Symbol*>(arg);
442+
443+ auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
444+ if (parameterPack) return parameterPack;
445+ }
446+
447+ return nullptr;
448+ }
449+
342450${ code . join ( "\n" ) }
343451
344452} // namespace cxx
0 commit comments