Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/frontend/cxx/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include <cxx/memory_layout.h>
#include <cxx/preprocessor.h>
#include <cxx/private/path.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
Expand Down Expand Up @@ -439,7 +438,7 @@ void Frontend::Private::dumpTokens(std::ostream& out) {
void Frontend::Private::dumpSymbols(std::ostream& out) {
if (!cli.opt_dump_symbols) return;
auto globalScope = unit_->globalScope();
auto globalNamespace = globalScope->owner();
auto globalNamespace = globalScope;
cxx::dump(out, globalNamespace);
}

Expand Down
2 changes: 1 addition & 1 deletion src/lsp/cxx/lsp/cxx_document.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
#include <cxx/macos_toolchain.h>
#include <cxx/preprocessor.h>
#include <cxx/private/path.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
#include <cxx/views/symbols.h>
#include <cxx/wasm32_wasi_toolchain.h>
#include <cxx/windows_toolchain.h>

Expand Down
6 changes: 2 additions & 4 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,10 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
std::vector<mlir::Type> inputTypes;
std::vector<mlir::Type> resultTypes;

if (!functionSymbol->isStatic() &&
functionSymbol->enclosingSymbol()->isClass()) {
if (!functionSymbol->isStatic() && functionSymbol->parent()->isClass()) {
// if it is a non static member function, we need to add the `this` pointer

auto classSymbol =
symbol_cast<ClassSymbol>(functionSymbol->enclosingSymbol());
auto classSymbol = symbol_cast<ClassSymbol>(functionSymbol->parent());

inputTypes.push_back(builder_.getType<mlir::cxx::PointerType>(
convertType(classSymbol->type())));
Expand Down
18 changes: 8 additions & 10 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
#include <cxx/ast.h>
#include <cxx/control.h>
#include <cxx/external_name_encoder.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
#include <cxx/views/symbols.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
Expand All @@ -40,7 +40,7 @@ namespace cxx {
struct Codegen::DeclarationVisitor {
Codegen& gen;

void allocateLocals(ScopedSymbol* block);
void allocateLocals(ScopeSymbol* block);

auto operator()(SimpleDeclarationAST* ast) -> DeclarationResult;
auto operator()(AsmDeclarationAST* ast) -> DeclarationResult;
Expand Down Expand Up @@ -142,8 +142,8 @@ auto Codegen::lambdaSpecifier(LambdaSpecifierAST* ast)
return {};
}

void Codegen::DeclarationVisitor::allocateLocals(ScopedSymbol* block) {
for (auto symbol : block->scope()->symbols()) {
void Codegen::DeclarationVisitor::allocateLocals(ScopeSymbol* block) {
for (auto symbol : views::members(block)) {
if (auto nestedBlock = symbol_cast<BlockSymbol>(symbol)) {
allocateLocals(nestedBlock);
continue;
Expand Down Expand Up @@ -380,10 +380,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
mlir::Value thisValue;

// if this is a non static member function, we need to allocate the `this`
if (!functionSymbol->isStatic() &&
functionSymbol->enclosingSymbol()->isClass()) {
auto classSymbol =
symbol_cast<ClassSymbol>(functionSymbol->enclosingSymbol());
if (!functionSymbol->isStatic() && functionSymbol->parent()->isClass()) {
auto classSymbol = symbol_cast<ClassSymbol>(functionSymbol->parent());

auto thisType = gen.convertType(classSymbol->type());
auto ptrType = gen.builder_.getType<mlir::cxx::PointerType>(thisType);
Expand All @@ -396,13 +394,13 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
}

FunctionParametersSymbol* params = nullptr;
for (auto member : ast->symbol->scope()->symbols()) {
for (auto member : views::members(ast->symbol)) {
params = symbol_cast<FunctionParametersSymbol>(member);
if (!params) continue;

int argc = 0;
auto args = gen.entryBlock_->getArguments();
for (auto param : params->scope()->symbols()) {
for (auto param : views::members(params)) {
auto arg = symbol_cast<ParameterSymbol>(param);
if (!arg) continue;

Expand Down
6 changes: 3 additions & 3 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
#include <cxx/control.h>
#include <cxx/literals.h>
#include <cxx/memory_layout.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
#include <cxx/views/symbols.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
Expand Down Expand Up @@ -645,8 +645,8 @@ auto Codegen::ExpressionVisitor::operator()(MemberExpressionAST* ast)
// todo: introduce ClassLayout to avoid linear searches and support c++
// class layout
int fieldIndex = 0;
auto classSymbol = symbol_cast<ClassSymbol>(field->enclosingSymbol());
for (auto member : classSymbol->scope()->symbols()) {
auto classSymbol = symbol_cast<ClassSymbol>(field->parent());
for (auto member : cxx::views::members(classSymbol)) {
auto f = symbol_cast<FieldSymbol>(member);
if (!f) continue;
if (f->isStatic()) continue;
Expand Down
8 changes: 4 additions & 4 deletions src/mlir/cxx/mlir/codegen_units.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
#include <cxx/ast_visitor.h>
#include <cxx/control.h>
#include <cxx/memory_layout.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/views/symbols.h>

// mlir
#include <llvm/IR/DataLayout.h>
Expand Down Expand Up @@ -69,7 +69,7 @@ struct Codegen::UnitVisitor {
UnitVisitor& p;

void operator()(NamespaceSymbol* symbol) {
for (auto member : symbol->scope()->symbols()) {
for (auto member : views::members(symbol)) {
visit(*this, member);
}
}
Expand All @@ -86,7 +86,7 @@ struct Codegen::UnitVisitor {
}

if (!symbol->templateParameters()) {
for (auto member : symbol->scope()->symbols()) {
for (auto member : views::members(symbol)) {
visit(*this, member);
}
}
Expand Down Expand Up @@ -149,7 +149,7 @@ auto Codegen::UnitVisitor::operator()(TranslationUnitAST* ast) -> UnitResult {

std::swap(gen.module_, module);

visit(visitor, gen.unit_->globalScope()->owner());
visit(visitor, gen.unit_->globalScope());

#if false
ForEachExternalDefinition forEachExternalDefinition;
Expand Down
4 changes: 2 additions & 2 deletions src/mlir/cxx/mlir/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
#include <cxx/literals.h>
#include <cxx/memory_layout.h>
#include <cxx/names.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
#include <cxx/views/symbols.h>

#include <format>

Expand Down Expand Up @@ -294,7 +294,7 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {

std::vector<mlir::Type> memberTypes;

for (auto member : classSymbol->scope()->symbols()) {
for (auto member : views::members(classSymbol)) {
auto field = symbol_cast<FieldSymbol>(member);
if (!field) continue;
if (field->isStatic()) continue;
Expand Down
1 change: 0 additions & 1 deletion src/parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ add_library(cxx-parser
cxx/parser.cc
cxx/path.cc
cxx/preprocessor.cc
cxx/scope.cc
cxx/source_location.cc
cxx/symbol_chain_view.cc
cxx/symbol_printer.cc
Expand Down
2 changes: 1 addition & 1 deletion src/parser/cxx/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class MemInitializerAST : public AST {
class NestedNameSpecifierAST : public AST {
public:
using AST::AST;
ScopedSymbol* symbol = nullptr;
ScopeSymbol* symbol = nullptr;
};

class NewInitializerAST : public AST {
Expand Down
87 changes: 45 additions & 42 deletions src/parser/cxx/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,46 +34,7 @@

namespace cxx {

auto ASTRewriter::make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument> {
auto control = unit->control();
auto interp = ASTInterpreter{unit};

std::vector<TemplateArgument> templateArguments;

for (auto arg : ListView{templateArgumentList}) {
if (auto exprArg = ast_cast<ExpressionTemplateArgumentAST>(arg)) {
auto expr = exprArg->expression;
auto value = interp.evaluate(expr);
if (!value.has_value()) {
#if false
unit->error(arg->firstSourceLocation(),
"template argument is not a constant expression");
#endif
continue;
}

// ### need to set scope and location
auto templArg = control->newVariableSymbol(nullptr, {});
templArg->setInitializer(expr);
templArg->setType(control->add_const(expr->type));
templArg->setConstValue(value);
templateArguments.push_back(templArg);
} else if (auto typeArg = ast_cast<TypeTemplateArgumentAST>(arg)) {
auto type = typeArg->typeId->type;
// ### need to set scope and location
auto templArg = control->newTypeAliasSymbol(nullptr, {});
templArg->setType(type);
templateArguments.push_back(templArg);
}
}

return templateArguments;
}

ASTRewriter::ASTRewriter(TranslationUnit* unit, Scope* scope,
ASTRewriter::ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
const std::vector<TemplateArgument>& templateArguments)
: unit_(unit), templateArguments_(templateArguments), binder_(unit_) {
binder_.setScope(scope);
Expand Down Expand Up @@ -167,7 +128,7 @@ auto ASTRewriter::instantiateClassTemplate(
return subst;
}

auto parentScope = classSymbol->enclosingSymbol()->scope();
auto parentScope = classSymbol->enclosingNonTemplateParametersScope();

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};

Expand Down Expand Up @@ -222,7 +183,10 @@ auto ASTRewriter::instantiateTypeAliasTemplate(
}
#endif

auto parentScope = typeAliasSymbol->enclosingSymbol()->scope();
auto parentScope = typeAliasSymbol->parent();
while (parentScope->isTemplateParameters()) {
parentScope = parentScope->parent();
}

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};

Expand All @@ -236,4 +200,43 @@ auto ASTRewriter::instantiateTypeAliasTemplate(
return instance->symbol;
}

auto ASTRewriter::make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument> {
auto control = unit->control();
auto interp = ASTInterpreter{unit};

std::vector<TemplateArgument> templateArguments;

for (auto arg : ListView{templateArgumentList}) {
if (auto exprArg = ast_cast<ExpressionTemplateArgumentAST>(arg)) {
auto expr = exprArg->expression;
auto value = interp.evaluate(expr);
if (!value.has_value()) {
#if false
unit->error(arg->firstSourceLocation(),
"template argument is not a constant expression");
#endif
continue;
}

// ### need to set scope and location
auto templArg = control->newVariableSymbol(nullptr, {});
templArg->setInitializer(expr);
templArg->setType(control->add_const(expr->type));
templArg->setConstValue(value);
templateArguments.push_back(templArg);
} else if (auto typeArg = ast_cast<TypeTemplateArgumentAST>(arg)) {
auto type = typeArg->typeId->type;
// ### need to set scope and location
auto templArg = control->newTypeAliasSymbol(nullptr, {});
templArg->setType(type);
templateArguments.push_back(templArg);
}
}

return templateArguments;
}

} // namespace cxx
20 changes: 12 additions & 8 deletions src/parser/cxx/ast_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,17 @@ class ASTRewriter {
TranslationUnit* unit, List<TemplateArgumentAST*>* templateArgumentList,
TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*;

[[nodiscard]] static auto make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument>;

explicit ASTRewriter(TranslationUnit* unit, Scope* scope,
explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope,
const std::vector<TemplateArgument>& templateArguments);
~ASTRewriter();

[[nodiscard]] auto translationUnit() const -> TranslationUnit* {
return unit_;
}

[[nodiscard]] auto declaration(DeclarationAST* ast) -> DeclarationAST*;
[[nodiscard]] auto declaration(DeclarationAST* ast,
TemplateDeclarationAST* templateHead = nullptr)
-> DeclarationAST*;

private:
[[nodiscard]] auto templateArguments() const
Expand All @@ -76,6 +73,11 @@ class ASTRewriter {
[[nodiscard]] auto restrictedToDeclarations() const -> bool;
void setRestrictedToDeclarations(bool restrictedToDeclarations);

[[nodiscard]] static auto make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument>;

// run on the base nodes
[[nodiscard]] auto unit(UnitAST* ast) -> UnitAST*;
[[nodiscard]] auto statement(StatementAST* ast) -> StatementAST*;
Expand All @@ -85,7 +87,9 @@ class ASTRewriter {
[[nodiscard]] auto designator(DesignatorAST* ast) -> DesignatorAST*;
[[nodiscard]] auto templateParameter(TemplateParameterAST* ast)
-> TemplateParameterAST*;
[[nodiscard]] auto specifier(SpecifierAST* ast) -> SpecifierAST*;
[[nodiscard]] auto specifier(SpecifierAST* ast,
TemplateDeclarationAST* templateHead = nullptr)
-> SpecifierAST*;
[[nodiscard]] auto ptrOperator(PtrOperatorAST* ast) -> PtrOperatorAST*;
[[nodiscard]] auto coreDeclarator(CoreDeclaratorAST* ast)
-> CoreDeclaratorAST*;
Expand Down
Loading
Loading