Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/frontend/cxx/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool {

mlir::OpPrintingFlags flags;
if (cli.opt_g) {
flags.enableDebugInfo(true, true);
flags.enableDebugInfo(true, false);
}
ir.module->print(llvm::outs(), flags);
}
Expand Down
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
let results = (outs Cxx_FloatType:$result);
}

def Cxx_IntegralCastOp : Cxx_Op<"cast.integral"> {
let arguments = (ins AnyType:$value);

let results = (outs AnyType:$result);
}

//
// todo ops
//
Expand Down
20 changes: 20 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

// cxx
#include <cxx/control.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>

#include <format>
Expand All @@ -41,6 +42,25 @@ auto Codegen::currentBlockMightHaveTerminator() -> bool {
return block->mightHaveTerminator();
}

auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
auto var = symbol_cast<VariableSymbol>(symbol);
if (!var) return std::nullopt;

if (auto local = locals_.find(var); local != locals_.end()) {
return local->second;
}

auto type = convertType(var->type());
auto ptrType = builder_.getType<mlir::cxx::PointerType>(type);

auto loc = getLocation(var->location());
auto allocaOp = builder_.create<mlir::cxx::AllocaOp>(loc, ptrType);

locals_.emplace(var, allocaOp);

return allocaOp;
}

auto Codegen::getLocation(SourceLocation location) -> mlir::Location {
auto [filename, line, column] = unit_->tokenStartPosition(location);

Expand Down
4 changes: 4 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class Codegen {

[[nodiscard]] auto currentBlockMightHaveTerminator() -> bool;

[[nodiscard]] auto findOrCreateLocal(Symbol* symbol)
-> std::optional<mlir::Value>;

struct UnitVisitor;
struct DeclarationVisitor;
struct StatementVisitor;
Expand Down Expand Up @@ -277,6 +280,7 @@ class Codegen {
mlir::Block* exitBlock_ = nullptr;
mlir::cxx::AllocaOp exitValue_;
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
std::unordered_map<Symbol*, mlir::Value> locals_;
int count_ = 0;
};

Expand Down
40 changes: 33 additions & 7 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <cxx/ast.h>
#include <cxx/control.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>

// mlir
Expand Down Expand Up @@ -150,20 +151,44 @@ auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast)
-> DeclarationResult {
#if false
for (auto node : ListView{ast->attributeList}) {
auto value = gen(node);
auto value = gen.attributeSpecifier(node);
}

for (auto node : ListView{ast->declSpecifierList}) {
auto value = gen(node);
auto value = gen.specifier(node);
}

for (auto node : ListView{ast->initDeclaratorList}) {
auto value = gen(node);
auto value = gen.initDeclarator(node);
}

auto requiresClauseResult = gen(ast->requiresClause);
auto requiresClauseResult = gen.requiresClause(ast->requiresClause);
#endif

for (auto node : ListView{ast->initDeclaratorList}) {
auto var = symbol_cast<VariableSymbol>(node->symbol);
if (!var) continue;
if (!node->initializer) continue;

const auto loc = gen.getLocation(var->location());

auto local = gen.findOrCreateLocal(var);

if (!local.has_value()) {
gen.unit_->error(node->initializer->firstSourceLocation(),
std::format("cannot find local variable '{}'",
to_string(var->name())));
continue;
}

auto expressionResult = gen.expression(node->initializer);

const auto elementType = gen.convertType(var->type());

gen.builder_.create<mlir::cxx::StoreOp>(loc, expressionResult.value,
local.value());
}

return {};
}

Expand Down Expand Up @@ -319,10 +344,11 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
name += std::format("_{}", ++gen.count_);
}

const auto savedInsertionPoint = gen.builder_.saveInsertionPoint();
auto guard = mlir::OpBuilder::InsertionGuard(gen.builder_);

const auto loc = gen.getLocation(ast->symbol->location());

std::unordered_map<Symbol*, mlir::Value> locals;
auto func = gen.builder_.create<mlir::cxx::FuncOp>(loc, name, funcType);
auto entryBlock = &func.front();
auto exitBlock = gen.builder_.createBlock(&func.getBody());
Expand All @@ -345,6 +371,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
std::swap(gen.function_, func);
std::swap(gen.exitBlock_, exitBlock);
std::swap(gen.exitValue_, exitValue);
std::swap(gen.locals_, locals);

// generate code for the function body
auto functionBodyResult = gen.functionBody(ast->functionBody);
Expand Down Expand Up @@ -383,8 +410,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
std::swap(gen.function_, func);
std::swap(gen.exitBlock_, exitBlock);
std::swap(gen.exitValue_, exitValue);

gen.builder_.restoreInsertionPoint(savedInsertionPoint);
std::swap(gen.locals_, locals);

return {};
}
Expand Down
68 changes: 56 additions & 12 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

// cxx
#include <cxx/ast.h>
#include <cxx/ast_interpreter.h>
#include <cxx/literals.h>
#include <cxx/symbols.h>
#include <cxx/types.h>

namespace cxx {
Expand Down Expand Up @@ -255,6 +257,25 @@ auto Codegen::ExpressionVisitor::operator()(NestedExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(IdExpressionAST* ast)
-> ExpressionResult {
if (auto local = gen.findOrCreateLocal(ast->symbol)) {
return {local.value()};
}

if (auto enumerator = symbol_cast<EnumeratorSymbol>(ast->symbol)) {
auto value = enumerator->value().and_then([&](const ConstValue& value) {
ASTInterpreter interp{gen.unit_};
return interp.toInt(value);
});

if (value.has_value()) {
auto loc = gen.getLocation(ast->firstSourceLocation());
auto type = gen.convertType(enumerator->type());
auto op =
gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, *value);
return {op};
}
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

Expand Down Expand Up @@ -731,19 +752,44 @@ auto Codegen::ExpressionVisitor::operator()(DeleteExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(CastExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto typeIdResult = gen.typeId(ast->typeId);
auto expressionResult = gen.expression(ast->expression);
#endif

return {op};
return expressionResult;
}

auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast)
-> ExpressionResult {
auto loc = gen.getLocation(ast->firstSourceLocation());

switch (ast->castKind) {
case ImplicitCastKind::kLValueToRValueConversion: {
// generate a load
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);

auto op = gen.builder_.create<mlir::cxx::LoadOp>(loc, resultType,
expressionResult.value);

return {op};
}

case ImplicitCastKind::kIntegralConversion:
case ImplicitCastKind::kIntegralPromotion: {
// generate a cast
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);

auto op = gen.builder_.create<mlir::cxx::IntegralCastOp>(
loc, resultType, expressionResult.value);

return {op};
}

default:
break;

} // switch

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

Expand Down Expand Up @@ -891,14 +937,12 @@ auto Codegen::ExpressionVisitor::operator()(ConditionExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(EqualInitializerAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
// auto op =
// gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto expressionResult = gen.expression(ast->expression);
#endif

return {op};
return expressionResult;
}

auto Codegen::ExpressionVisitor::operator()(BracedInitListAST* ast)
Expand Down
6 changes: 1 addition & 5 deletions src/mlir/cxx/mlir/codegen_statements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,7 @@ void Codegen::StatementVisitor::operator()(GotoStatementAST* ast) {
}

void Codegen::StatementVisitor::operator()(DeclarationStatementAST* ast) {
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto declarationResult = gen(ast->declaration);
#endif
auto declarationResult = gen.declaration(ast->declaration);
}

void Codegen::StatementVisitor::operator()(TryBlockStatementAST* ast) {
Expand Down
4 changes: 3 additions & 1 deletion src/mlir/cxx/mlir/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ auto Codegen::ConvertType::operator()(const OverloadSetType* type)

auto Codegen::ConvertType::operator()(const BuiltinVaListType* type)
-> mlir::Type {
return getExprType();
// todo: toolchain specific
auto voidType = gen.builder_.getType<mlir::cxx::VoidType>();
return gen.builder_.getType<mlir::cxx::PointerType>(voidType);
}

} // namespace cxx
61 changes: 59 additions & 2 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
auto resultType = LLVM::LLVMPointerType::get(context);
auto elementType = typeConverter->convertType(ptrTy.getElementType());

if (!elementType) {
return rewriter.notifyMatchFailure(
op, "failed to convert element type of alloca");
}

auto size = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));

Expand Down Expand Up @@ -240,6 +245,58 @@ class FloatConstantOpLowering
}
};

class IntegralCastOpLowering : public OpConversionPattern<cxx::IntegralCastOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::IntegralCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert integral cast type");
}

const auto sourceType = dyn_cast<cxx::IntegerType>(op.getValue().getType());
const auto targetType = dyn_cast<cxx::IntegerType>(op.getType());
const auto isSigned = targetType.getIsSigned();

if (sourceType.getWidth() == targetType.getWidth()) {
// no conversion needed, just replace the op with the value
rewriter.replaceOp(op, adaptor.getValue());
return success();
}

if (targetType.getWidth() < sourceType.getWidth()) {
// truncation
if (isSigned) {
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType,
adaptor.getValue());
} else {
rewriter.replaceOpWithNewOp<LLVM::ZExtOp>(op, resultType,
adaptor.getValue());
}
return success();
}

// extension

if (isSigned) {
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, resultType,
adaptor.getValue());
} else {
rewriter.replaceOpWithNewOp<LLVM::ZExtOp>(op, resultType,
adaptor.getValue());
}

return success();
}
};

class CxxToLLVMLoweringPass
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
public:
Expand Down Expand Up @@ -328,8 +385,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
RewritePatternSet patterns(context);
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering>(typeConverter,
context);
IntConstantOpLowering, FloatConstantOpLowering,
IntegralCastOpLowering>(typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
typeConverter);
Expand Down
Loading