diff --git a/src/frontend/cxx/frontend.cc b/src/frontend/cxx/frontend.cc index 7f1e0fd5..78fc6531 100644 --- a/src/frontend/cxx/frontend.cc +++ b/src/frontend/cxx/frontend.cc @@ -340,7 +340,7 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool { if (!shouldExit) { unit.parse(ParserConfiguration{ - .checkTypes = cli.opt_fcheck, + .checkTypes = cli.opt_fcheck || unit.language() == LanguageKind::kC, .fuzzyTemplateResolution = true, .reflect = !cli.opt_fno_reflect, }); diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 765d80bc..9a6dd7a3 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -133,7 +133,7 @@ def Cxx_CallOp : Cxx_Op<"call"> { OptionalAttr:$res_attrs ); - let results = (outs AnyType); + let results = (outs Optional:$result); } def Cxx_AllocaOp : Cxx_Op<"alloca"> { diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 44903236..dab589e2 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -465,7 +465,7 @@ auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast) -> ExpressionResult { - auto check_direct_call = [&]() -> ExpressionResult { + auto check_direct_call = [&]() -> std::optional { auto func = ast->baseExpression; while (auto nested = ast_cast(func)) { @@ -490,16 +490,20 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast) auto loc = gen.getLocation(ast->lparenLoc); auto functionType = type_cast(functionSymbol->type()); - auto resultType = gen.convertType(functionType->returnType()); + mlir::SmallVector resultTypes; + if (!control()->is_void(functionType->returnType())) { + resultTypes.push_back(gen.convertType(functionType->returnType())); + } + auto op = gen.builder_.create( - loc, resultType, funcOp.getSymName(), arguments, mlir::ArrayAttr{}, + loc, resultTypes, funcOp.getSymName(), arguments, mlir::ArrayAttr{}, mlir::ArrayAttr{}); - return {op}; + return ExpressionResult{op.getResult()}; }; - if (auto op = check_direct_call(); op.value) { - return op; + if (auto op = check_direct_call(); op.has_value()) { + return *op; } auto op = diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 76a7038f..ad65af0f 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -124,16 +124,16 @@ class CallOpLowering : public OpConversionPattern { argumentTypes.push_back(convertedType); } - auto resultType = typeConverter->convertType(op.getType()); - if (!resultType) { + SmallVector resultTypes; + if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) { return rewriter.notifyMatchFailure(op, "failed to convert call result types"); } auto llvmCallOp = rewriter.create( - op.getLoc(), resultType, adaptor.getCallee(), adaptor.getInputs()); + op.getLoc(), resultTypes, adaptor.getCallee(), adaptor.getInputs()); - rewriter.replaceOp(op, llvmCallOp.getResults()); + rewriter.replaceOp(op, llvmCallOp); return success(); } }; @@ -875,6 +875,10 @@ void CxxToLLVMLoweringPass::runOnOperation() { // set up the type converter LLVMTypeConverter typeConverter{context}; + typeConverter.addConversion([](cxx::VoidType type) { + return LLVM::LLVMVoidType::get(type.getContext()); + }); + typeConverter.addConversion([](cxx::BoolType type) { // todo: i8/i32 for data and i1 for control flow return IntegerType::get(type.getContext(), 1); diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 46ea9882..07a8f9be 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -3288,6 +3288,8 @@ void Parser::parse_init_statement(StatementAST*& yyast) { void Parser::parse_condition(ExpressionAST*& yyast, const ExprContext& ctx) { auto lookat_condition = [&] { + if (!is_parsing_cxx()) return false; + LookaheadParser lookahead{this}; List* attributes = nullptr; diff --git a/src/parser/cxx/type_checker.cc b/src/parser/cxx/type_checker.cc index 4e985020..19661448 100644 --- a/src/parser/cxx/type_checker.cc +++ b/src/parser/cxx/type_checker.cc @@ -140,6 +140,9 @@ struct TypeChecker::Visitor { [[nodiscard]] auto check_static_cast(ExpressionAST*& expression, const Type* targetType) -> bool; + [[nodiscard]] auto check_const_cast(ExpressionAST*& expression, + const Type* targetType) -> bool; + [[nodiscard]] auto check_cast_to_derived(ExpressionAST* expression, const Type* targetType) -> bool; @@ -467,13 +470,13 @@ void TypeChecker::Visitor::operator()(CallExpressionAST* ast) { if (!functionType) { if (control()->is_pointer(ast->baseExpression->type)) { - // ressolve pointer to function type + // resolve pointer to function type functionType = type_cast( control()->get_element_type(ast->baseExpression->type)); - } - if (functionType && is_parsing_c()) { - (void)ensure_prvalue(ast->baseExpression); + if (functionType && is_parsing_c()) { + (void)ensure_prvalue(ast->baseExpression); + } } } @@ -489,6 +492,46 @@ void TypeChecker::Visitor::operator()(CallExpressionAST* ast) { } // TODO: check the arguments + if (is_parsing_c()) { + const auto& argumentTypes = functionType->parameterTypes(); + + int argc = 0; + for (auto it = ast->expressionList; it; it = it->next) { + if (!it->value) { + error(ast->firstSourceLocation(), + "invalid call with null argument expression"); + continue; + } + + if (argc >= argumentTypes.size()) { + if (functionType->isVariadic()) { + // do the promotion for the variadic arguments + (void)ensure_prvalue(it->value); + adjust_cv(it->value); + + if (integral_promotion(it->value)) continue; + if (floating_point_promotion(it->value)) continue; + + continue; + } + + error(it->value->firstSourceLocation(), + std::format("too many arguments for function of type '{}'", + to_string(functionType))); + break; + } + + auto targetType = argumentTypes[argc]; + ++argc; + + if (!implicit_conversion(it->value, targetType)) { + error(it->value->firstSourceLocation(), + std::format("invalid argument of type '{}' for parameter of type " + "'{}'", + to_string(it->value->type), to_string(targetType))); + } + } + } ast->type = functionType->returnType(); @@ -579,13 +622,23 @@ void TypeChecker::Visitor::operator()(CppCastExpressionAST* ast) { check_cpp_cast_expression(ast); switch (check.unit_->tokenKind(ast->castLoc)) { - case TokenKind::T_STATIC_CAST: + case TokenKind::T_STATIC_CAST: { if (check_static_cast(ast->expression, ast->type)) break; error( ast->firstSourceLocation(), std::format("invalid static_cast of '{}' to '{}'", to_string(ast->expression->type), to_string(ast->type))); break; + } + + case TokenKind::T_CONST_CAST: { + if (check_const_cast(ast->expression, ast->type)) break; + error( + ast->firstSourceLocation(), + std::format("invalid const_cast of '{}' to '{}'", + to_string(ast->expression->type), to_string(ast->type))); + break; + } default: break; @@ -653,6 +706,11 @@ auto TypeChecker::Visitor::check_static_cast(ExpressionAST*& expression, return true; } +auto TypeChecker::Visitor::check_const_cast(ExpressionAST*& expression, + const Type* targetType) -> bool { + return false; +} + auto TypeChecker::Visitor::check_cast_to_derived(ExpressionAST* expression, const Type* targetType) -> bool { @@ -967,6 +1025,7 @@ void TypeChecker::Visitor::operator()(CastExpressionAST* ast) { } if (check_static_cast(ast->expression, ast->type)) return; + if (check_const_cast(ast->expression, ast->type)) return; // check the other casts } @@ -1659,9 +1718,11 @@ auto TypeChecker::Visitor::pointer_conversion(ExpressionAST*& expr, const auto destinationPointerType = type_cast(destinationType); if (!destinationPointerType) return false; - if (control()->get_cv_qualifiers(pointerType->elementType()) != - control()->get_cv_qualifiers(destinationPointerType->elementType())) - return false; + auto sourceCv = control()->get_cv_qualifiers(pointerType->elementType()); + auto targetCv = + control()->get_cv_qualifiers(destinationPointerType->elementType()); + + if (!check_cv_qualifiers(targetCv, sourceCv)) return false; if (!control()->is_void(destinationPointerType->elementType())) return false; @@ -1812,7 +1873,19 @@ auto TypeChecker::Visitor::temporary_materialization_conversion( auto TypeChecker::Visitor::qualification_conversion(ExpressionAST*& expr, const Type* destinationType) -> bool { - return false; + auto type = get_qualification_combined_type(expr->type, destinationType); + if (!type) return false; + + if (!control()->is_same(destinationType, type)) return false; + + auto cast = make_node(arena()); + cast->castKind = ImplicitCastKind::kQualificationConversion; + cast->expression = expr; + cast->type = destinationType; + cast->valueCategory = expr->valueCategory; + expr = cast; + + return true; } auto TypeChecker::Visitor::ensure_prvalue(ExpressionAST*& expr) -> bool {