Skip to content

Commit 2d78c28

Browse files
committed
Integrate AST rewriter
1 parent c9224f1 commit 2d78c28

40 files changed

+5848
-6364
lines changed

packages/cxx-gen-ast/src/new_ast_rewriter_cc.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ if (auto param = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
359359
std::optional<int> index{n};
360360
std::swap(rewrite.elementIndex_, index);
361361
362-
auto expression = rewrite(ast->expression);
362+
auto expression = rewrite.expression(ast->expression);
363363
if (!current) {
364364
current = expression;
365365
} else {

src/mlir/cxx/mlir/codegen.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
109109
std::vector<mlir::Type> inputTypes;
110110
std::vector<mlir::Type> resultTypes;
111111

112+
if (!functionSymbol->isStatic() &&
113+
functionSymbol->enclosingSymbol()->isClass()) {
114+
// if it is a non static member function, we need to add the `this` pointer
115+
116+
auto classSymbol =
117+
symbol_cast<ClassSymbol>(functionSymbol->enclosingSymbol());
118+
119+
inputTypes.push_back(builder_.getType<mlir::cxx::PointerType>(
120+
convertType(classSymbol->type())));
121+
}
122+
112123
for (auto paramTy : functionType->parameterTypes()) {
113124
inputTypes.push_back(convertType(paramTy));
114125
}

src/mlir/cxx/mlir/codegen.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ class Codegen {
314314
mlir::ModuleOp module_;
315315
mlir::cxx::FuncOp function_;
316316
TranslationUnit* unit_ = nullptr;
317+
mlir::Block* entryBlock_ = nullptr;
317318
mlir::Block* exitBlock_ = nullptr;
318319
mlir::cxx::AllocaOp exitValue_;
320+
mlir::Value thisValue_;
319321
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
320322
std::unordered_map<Symbol*, mlir::Value> locals_;
321323
std::unordered_map<FunctionSymbol*, mlir::cxx::FuncOp> funcOps_;

src/mlir/cxx/mlir/codegen_declarations.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,9 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
348348

349349
// Add the function body.
350350
auto entryBlock = gen.builder_.createBlock(&func.getBody());
351-
for (const auto& input : func.getFunctionType().getInputs()) {
351+
auto inputs = func.getFunctionType().getInputs();
352+
353+
for (const auto& input : inputs) {
352354
entryBlock->addArgument(input, loc);
353355
}
354356

@@ -370,17 +372,36 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
370372

371373
// function state
372374
std::swap(gen.function_, func);
375+
std::swap(gen.entryBlock_, entryBlock);
373376
std::swap(gen.exitBlock_, exitBlock);
374377
std::swap(gen.exitValue_, exitValue);
375378
std::swap(gen.locals_, locals);
376379

380+
mlir::Value thisValue;
381+
382+
// if this is a non static member function, we need to allocate the `this`
383+
if (!functionSymbol->isStatic() &&
384+
functionSymbol->enclosingSymbol()->isClass()) {
385+
auto classSymbol =
386+
symbol_cast<ClassSymbol>(functionSymbol->enclosingSymbol());
387+
388+
auto thisType = gen.convertType(classSymbol->type());
389+
auto ptrType = gen.builder_.getType<mlir::cxx::PointerType>(thisType);
390+
391+
thisValue = gen.newTemp(classSymbol->type(), ast->firstSourceLocation());
392+
393+
// store the `this` pointer in the entry block
394+
gen.builder_.create<mlir::cxx::StoreOp>(
395+
loc, gen.entryBlock_->getArgument(0), thisValue);
396+
}
397+
377398
FunctionParametersSymbol* params = nullptr;
378399
for (auto member : ast->symbol->scope()->symbols()) {
379400
params = symbol_cast<FunctionParametersSymbol>(member);
380401
if (!params) continue;
381402

382403
int argc = 0;
383-
auto args = entryBlock->getArguments();
404+
auto args = gen.entryBlock_->getArguments();
384405
for (auto param : params->scope()->symbols()) {
385406
auto arg = symbol_cast<ParameterSymbol>(param);
386407
if (!arg) continue;
@@ -399,6 +420,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
399420
}
400421
}
401422

423+
std::swap(gen.thisValue_, thisValue);
424+
402425
allocateLocals(functionSymbol);
403426

404427
// generate code for the function body
@@ -428,7 +451,10 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
428451
}
429452

430453
// restore the state
454+
std::swap(gen.thisValue_, thisValue);
455+
431456
std::swap(gen.function_, func);
457+
std::swap(gen.entryBlock_, entryBlock);
432458
std::swap(gen.exitBlock_, exitBlock);
433459
std::swap(gen.exitValue_, exitValue);
434460
std::swap(gen.locals_, locals);

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,15 @@ auto Codegen::ExpressionVisitor::operator()(ObjectLiteralExpressionAST* ast)
299299

300300
auto Codegen::ExpressionVisitor::operator()(ThisExpressionAST* ast)
301301
-> ExpressionResult {
302-
auto op =
303-
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
304-
return {op};
302+
auto type = gen.convertType(ast->type);
303+
auto loc = gen.getLocation(ast->firstSourceLocation());
304+
305+
auto loadOp =
306+
gen.builder_.create<mlir::cxx::LoadOp>(loc, type, gen.thisValue_);
307+
return {loadOp};
308+
// auto op =
309+
// gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
310+
// return {op};
305311
}
306312

307313
auto Codegen::ExpressionVisitor::operator()(GenericSelectionExpressionAST* ast)
@@ -499,6 +505,38 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
499505
func = nested->expression;
500506
}
501507

508+
if (auto member = ast_cast<MemberExpressionAST>(func)) {
509+
auto thisValue = gen.expression(member->baseExpression);
510+
auto functionSymbol = symbol_cast<FunctionSymbol>(member->symbol);
511+
512+
auto funcOp = gen.findOrCreateFunction(functionSymbol);
513+
514+
mlir::SmallVector<mlir::Value> arguments;
515+
arguments.push_back(thisValue.value);
516+
for (auto node : ListView{ast->expressionList}) {
517+
auto value = gen.expression(node);
518+
arguments.push_back(value.value);
519+
}
520+
521+
auto loc = gen.getLocation(ast->lparenLoc);
522+
523+
auto functionType = type_cast<FunctionType>(functionSymbol->type());
524+
mlir::SmallVector<mlir::Type> resultTypes;
525+
if (!control()->is_void(functionType->returnType())) {
526+
resultTypes.push_back(gen.convertType(functionType->returnType()));
527+
}
528+
529+
auto op = gen.builder_.create<mlir::cxx::CallOp>(
530+
loc, resultTypes, funcOp.getSymName(), arguments, mlir::TypeAttr{});
531+
532+
if (functionType->isVariadic()) {
533+
op.setVarCalleeType(
534+
cast<mlir::cxx::FunctionType>(gen.convertType(functionType)));
535+
}
536+
537+
return ExpressionResult{op.getResult()};
538+
}
539+
502540
auto id = ast_cast<IdExpressionAST>(func);
503541
if (!id) return {};
504542

src/mlir/cxx/mlir/codegen_units.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <cxx/ast_visitor.h>
2626
#include <cxx/control.h>
2727
#include <cxx/memory_layout.h>
28+
#include <cxx/scope.h>
29+
#include <cxx/symbols.h>
2830
#include <cxx/translation_unit.h>
2931

3032
// mlir
@@ -62,6 +64,38 @@ struct Codegen::UnitVisitor {
6264

6365
auto operator()(TranslationUnitAST* ast) -> UnitResult;
6466
auto operator()(ModuleUnitAST* ast) -> UnitResult;
67+
68+
struct VisitSymbols {
69+
UnitVisitor& p;
70+
71+
void operator()(NamespaceSymbol* symbol) {
72+
for (auto member : symbol->scope()->symbols()) {
73+
visit(*this, member);
74+
}
75+
}
76+
77+
void operator()(FunctionSymbol* symbol) {
78+
if (auto funcDecl = symbol->declaration()) {
79+
p.gen.declaration(funcDecl);
80+
}
81+
}
82+
83+
void operator()(ClassSymbol* symbol) {
84+
for (auto specialization : symbol->specializations()) {
85+
visit(*this, specialization.symbol);
86+
}
87+
88+
if (!symbol->templateParameters()) {
89+
for (auto member : symbol->scope()->symbols()) {
90+
visit(*this, member);
91+
}
92+
}
93+
}
94+
95+
void operator()(Symbol*) {
96+
// Do nothing for other symbols.
97+
}
98+
} visitor{*this};
6599
};
66100

67101
auto Codegen::operator()(UnitAST* ast) -> UnitResult {
@@ -115,6 +149,9 @@ auto Codegen::UnitVisitor::operator()(TranslationUnitAST* ast) -> UnitResult {
115149

116150
std::swap(gen.module_, module);
117151

152+
visit(visitor, gen.unit_->globalScope()->owner());
153+
154+
#if false
118155
ForEachExternalDefinition forEachExternalDefinition;
119156

120157
forEachExternalDefinition.functionCallback =
@@ -125,6 +162,7 @@ auto Codegen::UnitVisitor::operator()(TranslationUnitAST* ast) -> UnitResult {
125162
for (auto node : ListView{ast->declarationList}) {
126163
forEachExternalDefinition.accept(node);
127164
}
165+
#endif
128166

129167
std::swap(gen.module_, module);
130168

src/parser/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ add_library(cxx-parser
2525
cxx/ast_interpreter.cc
2626
cxx/ast_pretty_printer.cc
2727
cxx/ast_printer.cc
28+
cxx/ast_rewriter_declarations.cc
29+
cxx/ast_rewriter_declarators.cc
30+
cxx/ast_rewriter_expressions.cc
31+
cxx/ast_rewriter_names.cc
32+
cxx/ast_rewriter_specifiers.cc
33+
cxx/ast_rewriter_statements.cc
34+
cxx/ast_rewriter_units.cc
2835
cxx/ast_rewriter.cc
2936
cxx/ast_slot.cc
3037
cxx/ast_visitor.cc
@@ -54,7 +61,6 @@ add_library(cxx-parser
5461
cxx/scope.cc
5562
cxx/source_location.cc
5663
cxx/symbol_chain_view.cc
57-
cxx/symbol_instantiation.cc
5864
cxx/symbol_printer.cc
5965
cxx/symbols.cc
6066
cxx/token.cc

src/parser/cxx/ast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3703,7 +3703,7 @@ class SimpleTemplateIdAST final : public UnqualifiedIdAST {
37033703
List<TemplateArgumentAST*>* templateArgumentList = nullptr;
37043704
SourceLocation greaterLoc;
37053705
const Identifier* identifier = nullptr;
3706-
Symbol* primaryTemplateSymbol = nullptr;
3706+
Symbol* symbol = nullptr;
37073707

37083708
void accept(ASTVisitor* visitor) override { visitor->visit(this); }
37093709

0 commit comments

Comments
 (0)