Skip to content

Commit 264db70

Browse files
committed
Simulate template argument pack expansion
1 parent 8e012c8 commit 264db70

File tree

13 files changed

+374
-39
lines changed

13 files changed

+374
-39
lines changed

packages/cxx-gen-ast/src/new_ast_rewriter_cc.ts

Lines changed: 135 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,17 @@ export function new_ast_rewriter_cc({
6262
emit(` struct ASTRewriter::${className}Visitor {`);
6363
emit(` ASTRewriter& rewrite;`);
6464
emit(
65-
`[[nodiscard]] auto translationUnit() const -> TranslationUnit* { return rewrite.unit_; }`,
65+
`[[nodiscard]] auto translationUnit() const -> TranslationUnit* { return rewrite.unit_; }`
6666
);
6767
emit();
6868
emit(
69-
`[[nodiscard]] auto control() const -> Control* { return rewrite.control(); }`,
69+
`[[nodiscard]] auto control() const -> Control* { return rewrite.control(); }`
7070
);
7171
emit(
72-
`[[nodiscard]] auto arena() const -> Arena* { return rewrite.arena(); }`,
72+
`[[nodiscard]] auto arena() const -> Arena* { return rewrite.arena(); }`
7373
);
7474
emit(
75-
`[[nodiscard]] auto rewriter() const -> ASTRewriter* { return &rewrite; }`,
75+
`[[nodiscard]] auto rewriter() const -> ASTRewriter* { return &rewrite; }`
7676
);
7777
nodes.forEach(({ name }) => {
7878
emit();
@@ -81,9 +81,17 @@ export function new_ast_rewriter_cc({
8181
emit(` };`);
8282
});
8383

84-
const emitRewriterBody = (members: Member[], visitor: string = "rewrite") => {
84+
const emitRewriterBody = ({
85+
name,
86+
members,
87+
visitor = "rewrite",
88+
}: {
89+
name: string;
90+
members: Member[];
91+
visitor?: string;
92+
}) => {
8593
const blockSymbol = members.find(
86-
(m) => m.kind === "attribute" && m.type === "BlockSymbol",
94+
(m) => m.kind === "attribute" && m.type === "BlockSymbol"
8795
);
8896

8997
if (blockSymbol) {
@@ -114,11 +122,11 @@ export function new_ast_rewriter_cc({
114122
if (specsAttr) {
115123
emit();
116124
emit(
117-
`auto ${m.name}Type = getDeclaratorType(translationUnit(), copy->${m.name}, ${specsAttr}Ctx.getType());`,
125+
`auto ${m.name}Type = getDeclaratorType(translationUnit(), copy->${m.name}, ${specsAttr}Ctx.getType());`
118126
);
119127

120128
typeAttr = members.find(
121-
(m) => m.kind === "attribute" && m.name === "type",
129+
(m) => m.kind === "attribute" && m.name === "type"
122130
);
123131

124132
if (typeAttr) {
@@ -131,7 +139,7 @@ export function new_ast_rewriter_cc({
131139
} // switch
132140
} else {
133141
emit(
134-
`copy->${m.name} = ast_cast<${m.type}>(${visitor}(ast->${m.name}));`,
142+
`copy->${m.name} = ast_cast<${m.type}>(${visitor}(ast->${m.name}));`
135143
);
136144
}
137145
break;
@@ -150,7 +158,7 @@ export function new_ast_rewriter_cc({
150158
} // switch
151159

152160
emit(
153-
`for (auto ${m.name} = &copy->${m.name}; auto node : ListView{ast->${m.name}}) {`,
161+
`for (auto ${m.name} = &copy->${m.name}; auto node : ListView{ast->${m.name}}) {`
154162
);
155163

156164
switch (m.type) {
@@ -167,7 +175,7 @@ export function new_ast_rewriter_cc({
167175
emit(`*${m.name} = make_list_node(arena(), value);`);
168176
} else {
169177
emit(
170-
`*${m.name} = make_list_node(arena(), ast_cast<${m.type}>(value));`,
178+
`*${m.name} = make_list_node(arena(), ast_cast<${m.type}>(value));`
171179
);
172180
}
173181
emit(`${m.name} = &(*${m.name})->next;`);
@@ -188,6 +196,39 @@ export function new_ast_rewriter_cc({
188196
}
189197
case "attribute": {
190198
emit(` copy->${m.name} = ast->${m.name};`);
199+
200+
if (m.name == "symbol" && name == "NamedTypeSpecifierAST") {
201+
emit(`
202+
if (auto typeParameter = symbol_cast<TypeParameterSymbol>(copy->symbol)) {
203+
const auto& args = rewrite.templateArguments_;
204+
if (typeParameter && typeParameter->depth() == 0 &&
205+
typeParameter->index() < args.size()) {
206+
auto index = typeParameter->index();
207+
208+
if (auto sym = std::get_if<Symbol*>(&args[index])) {
209+
copy->symbol = *sym;
210+
}
211+
}
212+
}
213+
`);
214+
}
215+
216+
if (m.name == "symbol" && name == "IdExpressionAST") {
217+
emit(`
218+
if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
219+
param && param->depth() == 0 &&
220+
param->index() < rewrite.templateArguments_.size()) {
221+
auto symbolPtr =
222+
std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
223+
224+
if (!symbolPtr) {
225+
cxx_runtime_error("expected initializer for non-type template parameter");
226+
}
227+
228+
copy->symbol = *symbolPtr;
229+
copy->type = copy->symbol->type();
230+
}`);
231+
}
191232
break;
192233
}
193234
case "token": {
@@ -234,7 +275,7 @@ export function new_ast_rewriter_cc({
234275
switch (name) {
235276
case "InitDeclaratorAST":
236277
emit(
237-
`auto ASTRewriter::operator()(${name}* ast, const DeclSpecs& declSpecs) -> ${name}* {`,
278+
`auto ASTRewriter::operator()(${name}* ast, const DeclSpecs& declSpecs) -> ${name}* {`
238279
);
239280
break;
240281
default:
@@ -246,7 +287,7 @@ export function new_ast_rewriter_cc({
246287
emit();
247288
emit(` auto copy = make_node<${name}>(arena());`);
248289
emit();
249-
emitRewriterBody(members, "operator()");
290+
emitRewriterBody({ name, members, visitor: "operator()" });
250291
emit();
251292
emit(` return copy;`);
252293
emit(`}`);
@@ -259,38 +300,79 @@ export function new_ast_rewriter_cc({
259300
nodes.forEach(({ name, members }) => {
260301
emit();
261302
emit(
262-
`auto ASTRewriter::${className}Visitor::operator()(${name}* ast) -> ${base}* {`,
303+
`auto ASTRewriter::${className}Visitor::operator()(${name}* ast) -> ${base}* {`
263304
);
305+
264306
if (name === "IdExpressionAST") {
265307
emit(`
266-
if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
267-
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
268-
auto initializerPtr =
269-
std::get_if<ExpressionAST*>(&rewrite.templateArguments_[x->index()]);
270-
if (!initializerPtr) {
308+
if (auto param = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
309+
param && param->depth() == 0 &&
310+
param->index() < rewrite.templateArguments_.size()) {
311+
auto symbolPtr =
312+
std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
313+
314+
if (!symbolPtr) {
271315
cxx_runtime_error("expected initializer for non-type template parameter");
272316
}
273317
274-
auto initializer = rewrite(*initializerPtr);
318+
auto parameterPack = symbol_cast<ParameterPackSymbol>(*symbolPtr);
275319
276-
if (auto eq = ast_cast<EqualInitializerAST>(initializer)) {
277-
return eq->expression;
320+
if (parameterPack && parameterPack == rewrite.parameterPack_ &&
321+
rewrite.elementIndex_.has_value()) {
322+
auto idx = rewrite.elementIndex_.value();
323+
auto element = parameterPack->elements()[idx];
324+
if (auto var = symbol_cast<VariableSymbol>(element)) {
325+
return rewrite(var->initializer());
326+
}
278327
}
328+
}
329+
`);
330+
}
331+
if (name === "LeftFoldExpressionAST") {
332+
emit(`
333+
if (auto parameterPack = rewrite.getParameterPack(ast->expression)) {
334+
auto savedParameterPack = rewrite.parameterPack_;
335+
std::swap(rewrite.parameterPack_, parameterPack);
336+
337+
std::vector<ExpressionAST*> instantiations;
338+
ExpressionAST* current = nullptr;
339+
340+
int n = 0;
341+
for (auto element : rewrite.parameterPack_->elements()) {
342+
std::optional<int> index{n};
343+
std::swap(rewrite.elementIndex_, index);
344+
345+
auto expression = rewrite(ast->expression);
346+
if (!current) {
347+
current = expression;
348+
} else {
349+
auto binop = make_node<BinaryExpressionAST>(arena());
350+
binop->valueCategory = current->valueCategory;
351+
binop->type = current->type;
352+
binop->leftExpression = current;
353+
binop->op = ast->op;
354+
binop->opLoc = ast->opLoc;
355+
binop->rightExpression = expression;
356+
current = binop;
357+
}
279358
280-
if (auto bracedInit = ast_cast<BracedInitListAST>(initializer)) {
281-
if (bracedInit->expressionList && !bracedInit->expressionList->next) {
282-
return bracedInit->expressionList->value;
359+
std::swap(rewrite.elementIndex_, index);
360+
++n;
283361
}
362+
363+
std::swap(rewrite.parameterPack_, parameterPack);
364+
365+
return current;
284366
}
285-
}
286367
`);
287368
}
369+
288370
emit(` auto copy = make_node<${name}>(arena());`);
289371
emit();
290372
ast.baseMembers.get(base)?.forEach((m) => {
291373
emit(` copy->${m.name} = ast->${m.name};`);
292374
});
293-
emitRewriterBody(members, "rewrite");
375+
emitRewriterBody({ name, members, visitor: "rewrite" });
294376
emit();
295377
emit(` return copy;`);
296378
emit(`}`);
@@ -311,6 +393,9 @@ if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
311393
#include <cxx/symbols.h>
312394
#include <cxx/types.h>
313395
#include <cxx/binder.h>
396+
#include <cxx/ast_cursor.h>
397+
398+
#include <format>
314399
315400
namespace cxx {
316401
@@ -339,6 +424,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
339424
restrictedToDeclarations_ = restrictedToDeclarations;
340425
}
341426
427+
auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
428+
for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
429+
const auto& current = *cursor;
430+
if (!std::holds_alternative<AST*>(current.node)) continue;
431+
432+
auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node));
433+
if (!id) continue;
434+
435+
auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol);
436+
if (!param) continue;
437+
438+
if (param->depth() != 0) continue;
439+
440+
auto arg = templateArguments_[param->index()];
441+
auto argSymbol = std::get<Symbol*>(arg);
442+
443+
auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
444+
if (parameterPack) return parameterPack;
445+
}
446+
447+
return nullptr;
448+
}
449+
342450
${code.join("\n")}
343451
344452
} // namespace cxx

packages/cxx-gen-ast/src/new_ast_rewriter_h.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export function new_ast_rewriter_h({
5454
switch (name) {
5555
case "InitDeclaratorAST":
5656
emit(
57-
` [[nodiscard]] auto operator()(${name}* ast, const DeclSpecs& declSpecs) -> ${name}*;`,
57+
` [[nodiscard]] auto operator()(${name}* ast, const DeclSpecs& declSpecs) -> ${name}*;`
5858
);
5959
break;
6060
default:
@@ -117,8 +117,12 @@ ${code.join("\n")}
117117
private:
118118
[[nodiscard]] auto rewriter() -> ASTRewriter* { return this; }
119119
120+
[[nodiscard]] auto getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol*;
121+
120122
TypeChecker* typeChecker_ = nullptr;
121123
const std::vector<TemplateArgument>& templateArguments_;
124+
ParameterPackSymbol* parameterPack_ = nullptr;
125+
std::optional<int> elementIndex_;
122126
TranslationUnit* unit_ = nullptr;
123127
Binder binder_;
124128
bool restrictedToDeclarations_ = false;

0 commit comments

Comments
 (0)