Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/frontend/cxx/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool {

if (!shouldExit) {
unit.parse(ParserConfiguration{
.checkTypes = cli.opt_fcheck,
.checkTypes = cli.opt_fcheck || unit.language() == LanguageKind::kC,
.fuzzyTemplateResolution = true,
.reflect = !cli.opt_fno_reflect,
});
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def Cxx_CallOp : Cxx_Op<"call"> {
OptionalAttr<DictArrayAttr>:$res_attrs
);

let results = (outs AnyType);
let results = (outs Optional<AnyType>:$result);
}

def Cxx_AllocaOp : Cxx_Op<"alloca"> {
Expand Down
16 changes: 10 additions & 6 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
-> ExpressionResult {
auto check_direct_call = [&]() -> ExpressionResult {
auto check_direct_call = [&]() -> std::optional<ExpressionResult> {
auto func = ast->baseExpression;

while (auto nested = ast_cast<NestedExpressionAST>(func)) {
Expand All @@ -490,16 +490,20 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
auto loc = gen.getLocation(ast->lparenLoc);

auto functionType = type_cast<FunctionType>(functionSymbol->type());
auto resultType = gen.convertType(functionType->returnType());
mlir::SmallVector<mlir::Type> resultTypes;
if (!control()->is_void(functionType->returnType())) {
resultTypes.push_back(gen.convertType(functionType->returnType()));
}

auto op = gen.builder_.create<mlir::cxx::CallOp>(
loc, resultType, funcOp.getSymName(), arguments, mlir::ArrayAttr{},
loc, resultTypes, funcOp.getSymName(), arguments, mlir::ArrayAttr{},
mlir::ArrayAttr{});

return {op};
return ExpressionResult{op.getResult()};
};

if (auto op = check_direct_call(); op.value) {
return op;
if (auto op = check_direct_call(); op.has_value()) {
return *op;
}

auto op =
Expand Down
12 changes: 8 additions & 4 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,16 @@ class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
argumentTypes.push_back(convertedType);
}

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
SmallVector<Type> resultTypes;
if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) {
return rewriter.notifyMatchFailure(op,
"failed to convert call result types");
}

auto llvmCallOp = rewriter.create<LLVM::CallOp>(
op.getLoc(), resultType, adaptor.getCallee(), adaptor.getInputs());
op.getLoc(), resultTypes, adaptor.getCallee(), adaptor.getInputs());

rewriter.replaceOp(op, llvmCallOp.getResults());
rewriter.replaceOp(op, llvmCallOp);
return success();
}
};
Expand Down Expand Up @@ -875,6 +875,10 @@ void CxxToLLVMLoweringPass::runOnOperation() {
// set up the type converter
LLVMTypeConverter typeConverter{context};

typeConverter.addConversion([](cxx::VoidType type) {
return LLVM::LLVMVoidType::get(type.getContext());
});

typeConverter.addConversion([](cxx::BoolType type) {
// todo: i8/i32 for data and i1 for control flow
return IntegerType::get(type.getContext(), 1);
Expand Down
2 changes: 2 additions & 0 deletions src/parser/cxx/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3288,6 +3288,8 @@ void Parser::parse_init_statement(StatementAST*& yyast) {

void Parser::parse_condition(ExpressionAST*& yyast, const ExprContext& ctx) {
auto lookat_condition = [&] {
if (!is_parsing_cxx()) return false;

LookaheadParser lookahead{this};

List<AttributeSpecifierAST*>* attributes = nullptr;
Expand Down
91 changes: 82 additions & 9 deletions src/parser/cxx/type_checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ struct TypeChecker::Visitor {
[[nodiscard]] auto check_static_cast(ExpressionAST*& expression,
const Type* targetType) -> bool;

[[nodiscard]] auto check_const_cast(ExpressionAST*& expression,
const Type* targetType) -> bool;

[[nodiscard]] auto check_cast_to_derived(ExpressionAST* expression,
const Type* targetType) -> bool;

Expand Down Expand Up @@ -467,13 +470,13 @@ void TypeChecker::Visitor::operator()(CallExpressionAST* ast) {

if (!functionType) {
if (control()->is_pointer(ast->baseExpression->type)) {
// ressolve pointer to function type
// resolve pointer to function type
functionType = type_cast<FunctionType>(
control()->get_element_type(ast->baseExpression->type));
}

if (functionType && is_parsing_c()) {
(void)ensure_prvalue(ast->baseExpression);
if (functionType && is_parsing_c()) {
(void)ensure_prvalue(ast->baseExpression);
}
}
}

Expand All @@ -489,6 +492,46 @@ void TypeChecker::Visitor::operator()(CallExpressionAST* ast) {
}

// TODO: check the arguments
if (is_parsing_c()) {
const auto& argumentTypes = functionType->parameterTypes();

int argc = 0;
for (auto it = ast->expressionList; it; it = it->next) {
if (!it->value) {
error(ast->firstSourceLocation(),
"invalid call with null argument expression");
continue;
}

if (argc >= argumentTypes.size()) {
if (functionType->isVariadic()) {
// do the promotion for the variadic arguments
(void)ensure_prvalue(it->value);
adjust_cv(it->value);

if (integral_promotion(it->value)) continue;
if (floating_point_promotion(it->value)) continue;

continue;
}

error(it->value->firstSourceLocation(),
std::format("too many arguments for function of type '{}'",
to_string(functionType)));
break;
}

auto targetType = argumentTypes[argc];
++argc;

if (!implicit_conversion(it->value, targetType)) {
error(it->value->firstSourceLocation(),
std::format("invalid argument of type '{}' for parameter of type "
"'{}'",
to_string(it->value->type), to_string(targetType)));
}
}
}

ast->type = functionType->returnType();

Expand Down Expand Up @@ -579,13 +622,23 @@ void TypeChecker::Visitor::operator()(CppCastExpressionAST* ast) {
check_cpp_cast_expression(ast);

switch (check.unit_->tokenKind(ast->castLoc)) {
case TokenKind::T_STATIC_CAST:
case TokenKind::T_STATIC_CAST: {
if (check_static_cast(ast->expression, ast->type)) break;
error(
ast->firstSourceLocation(),
std::format("invalid static_cast of '{}' to '{}'",
to_string(ast->expression->type), to_string(ast->type)));
break;
}

case TokenKind::T_CONST_CAST: {
if (check_const_cast(ast->expression, ast->type)) break;
error(
ast->firstSourceLocation(),
std::format("invalid const_cast of '{}' to '{}'",
to_string(ast->expression->type), to_string(ast->type)));
break;
}

default:
break;
Expand Down Expand Up @@ -653,6 +706,11 @@ auto TypeChecker::Visitor::check_static_cast(ExpressionAST*& expression,
return true;
}

auto TypeChecker::Visitor::check_const_cast(ExpressionAST*& expression,
const Type* targetType) -> bool {
return false;
}

auto TypeChecker::Visitor::check_cast_to_derived(ExpressionAST* expression,
const Type* targetType)
-> bool {
Expand Down Expand Up @@ -967,6 +1025,7 @@ void TypeChecker::Visitor::operator()(CastExpressionAST* ast) {
}

if (check_static_cast(ast->expression, ast->type)) return;
if (check_const_cast(ast->expression, ast->type)) return;
// check the other casts
}

Expand Down Expand Up @@ -1659,9 +1718,11 @@ auto TypeChecker::Visitor::pointer_conversion(ExpressionAST*& expr,
const auto destinationPointerType = type_cast<PointerType>(destinationType);
if (!destinationPointerType) return false;

if (control()->get_cv_qualifiers(pointerType->elementType()) !=
control()->get_cv_qualifiers(destinationPointerType->elementType()))
return false;
auto sourceCv = control()->get_cv_qualifiers(pointerType->elementType());
auto targetCv =
control()->get_cv_qualifiers(destinationPointerType->elementType());

if (!check_cv_qualifiers(targetCv, sourceCv)) return false;

if (!control()->is_void(destinationPointerType->elementType()))
return false;
Expand Down Expand Up @@ -1812,7 +1873,19 @@ auto TypeChecker::Visitor::temporary_materialization_conversion(
auto TypeChecker::Visitor::qualification_conversion(ExpressionAST*& expr,
const Type* destinationType)
-> bool {
return false;
auto type = get_qualification_combined_type(expr->type, destinationType);
if (!type) return false;

if (!control()->is_same(destinationType, type)) return false;

auto cast = make_node<ImplicitCastExpressionAST>(arena());
cast->castKind = ImplicitCastKind::kQualificationConversion;
cast->expression = expr;
cast->type = destinationType;
cast->valueCategory = expr->valueCategory;
expr = cast;

return true;
}

auto TypeChecker::Visitor::ensure_prvalue(ExpressionAST*& expr) -> bool {
Expand Down
Loading