diff --git a/src/mlir/cxx/mlir/codegen_units.cc b/src/mlir/cxx/mlir/codegen_units.cc index f627219d..0a55ab3d 100644 --- a/src/mlir/cxx/mlir/codegen_units.cc +++ b/src/mlir/cxx/mlir/codegen_units.cc @@ -22,10 +22,28 @@ // cxx #include +#include #include namespace cxx { +namespace { +struct ForEachExternalDefinition final : ASTVisitor { + std::function functionCallback; + + void visit(TemplateDeclarationAST*) override { + // Skip template declarations, we only want to visit function definitions. + } + + void visit(FunctionDefinitionAST* ast) override { + if (functionCallback) functionCallback(ast); + + ASTVisitor::visit(ast); + } +}; + +} // namespace + struct Codegen::UnitVisitor { Codegen& gen; @@ -46,8 +64,15 @@ auto Codegen::UnitVisitor::operator()(TranslationUnitAST* ast) -> UnitResult { std::swap(gen.module_, module); + ForEachExternalDefinition forEachExternalDefinition; + + forEachExternalDefinition.functionCallback = + [&](FunctionDefinitionAST* function) { + auto functionResult = gen.declaration(function); + }; + for (auto node : ListView{ast->declarationList}) { - auto value = gen.declaration(node); + forEachExternalDefinition.accept(node); } std::swap(gen.module_, module); diff --git a/src/parser/cxx/binder.cc b/src/parser/cxx/binder.cc index 7ec35e08..1874116e 100644 --- a/src/parser/cxx/binder.cc +++ b/src/parser/cxx/binder.cc @@ -320,6 +320,14 @@ void Binder::bind(ParameterDeclarationAST* ast, const Decl& decl, bool inTemplateParameters) { ast->type = getDeclaratorType(unit_, ast->declarator, decl.specs.type()); + // decay the type of the parameters + if (control()->is_array(ast->type)) + ast->type = control()->add_pointer(control()->remove_extent(ast->type)); + else if (control()->is_function(ast->type)) + ast->type = control()->add_pointer(ast->type); + else if (control()->is_scalar(ast->type)) + ast->type = control()->remove_cv(ast->type); + if (auto declId = decl.declaratorId; declId && declId->unqualifiedId) { auto paramName = get_name(control(), declId->unqualifiedId); if (auto identifier = name_cast(paramName)) { diff --git a/src/parser/cxx/external_name_encoder.cc b/src/parser/cxx/external_name_encoder.cc index 4b3e0966..aa9ab398 100644 --- a/src/parser/cxx/external_name_encoder.cc +++ b/src/parser/cxx/external_name_encoder.cc @@ -236,7 +236,7 @@ struct ExternalNameEncoder::EncodeType { // todo: "Y" prefix for the bare function type encodes extern "C" encoder.out("F"); - encoder.encodeBareFunctionType(type); + encoder.encodeBareFunctionType(type, /*includeReturnType=*/true); if (type->refQualifier() == RefQualifier::kLvalue) encoder.out("R");