Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

def Cxx_Dialect : Dialect {
let name = "cxx";
Expand Down Expand Up @@ -103,17 +104,13 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {

let regions = (region AnyRegion:$body);

let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];

let extraClassDeclaration = [{
auto getArgumentTypes() -> ArrayRef<Type> { return getFunctionType().getInputs(); }
auto getResultTypes() -> ArrayRef<Type> { return getFunctionType().getResults(); }
auto getCallableRegion() -> Region* { return &getBody(); }
}];

let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
Expand All @@ -128,6 +125,17 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
let hasVerifier = 0;
}

def Cxx_CallOp : Cxx_Op<"call"> {
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);

let results = (outs AnyType);
}

def Cxx_AllocaOp : Cxx_Op<"alloca"> {
let arguments = (ins);

Expand Down
48 changes: 48 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@

// cxx
#include <cxx/control.h>
#include <cxx/external_name_encoder.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>

// mlir
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
Expand Down Expand Up @@ -79,6 +81,52 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
return allocaOp;
}

auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
-> mlir::cxx::FuncOp {
if (auto it = funcOps_.find(functionSymbol); it != funcOps_.end()) {
return it->second;
}

const auto functionType = type_cast<FunctionType>(functionSymbol->type());
const auto returnType = functionType->returnType();
const auto needsExitValue = !control()->is_void(returnType);

std::vector<mlir::Type> inputTypes;
std::vector<mlir::Type> resultTypes;

for (auto paramTy : functionType->parameterTypes()) {
inputTypes.push_back(convertType(paramTy));
}

if (needsExitValue) {
resultTypes.push_back(convertType(returnType));
}

auto funcType = builder_.getFunctionType(inputTypes, resultTypes);

std::string name;

if (functionSymbol->hasCLinkage()) {
name = to_string(functionSymbol->name());
} else {
ExternalNameEncoder encoder;
name = encoder.encode(functionSymbol);
}

const auto loc = getLocation(functionSymbol->location());

auto guard = mlir::OpBuilder::InsertionGuard(builder_);

builder_.setInsertionPointToStart(module_.getBody());

auto func = builder_.create<mlir::cxx::FuncOp>(
loc, name, funcType, mlir::ArrayAttr{}, mlir::ArrayAttr{});

funcOps_.insert_or_assign(functionSymbol, func);

return func;
}

auto Codegen::getLocation(SourceLocation location) -> mlir::Location {
auto [filename, line, column] = unit_->tokenStartPosition(location);

Expand Down
4 changes: 4 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ class Codegen {

[[nodiscard]] auto currentBlockMightHaveTerminator() -> bool;

[[nodiscard]] auto findOrCreateFunction(FunctionSymbol* functionSymbol)
-> mlir::cxx::FuncOp;

[[nodiscard]] auto findOrCreateLocal(Symbol* symbol)
-> std::optional<mlir::Value>;

Expand Down Expand Up @@ -308,6 +311,7 @@ class Codegen {
mlir::cxx::AllocaOp exitValue_;
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
std::unordered_map<Symbol*, mlir::Value> locals_;
std::unordered_map<FunctionSymbol*, mlir::cxx::FuncOp> funcOps_;
Loop loop_;
int count_ = 0;
};
Expand Down
42 changes: 9 additions & 33 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,46 +316,20 @@ auto Codegen::DeclarationVisitor::operator()(OpaqueEnumDeclarationAST* ast)
auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
-> DeclarationResult {
auto functionSymbol = ast->symbol;

auto func = gen.findOrCreateFunction(functionSymbol);
const auto functionType = type_cast<FunctionType>(functionSymbol->type());
const auto returnType = functionType->returnType();
const auto needsExitValue = !gen.control()->is_void(returnType);

std::vector<mlir::Type> inputTypes;
std::vector<mlir::Type> resultTypes;

for (auto paramTy : functionType->parameterTypes()) {
inputTypes.push_back(gen.convertType(paramTy));
}

if (needsExitValue) {
resultTypes.push_back(gen.convertType(returnType));
}

auto funcType = gen.builder_.getFunctionType(inputTypes, resultTypes);

std::vector<std::string> path;
for (Symbol* symbol = ast->symbol; symbol;
symbol = symbol->enclosingSymbol()) {
if (!symbol->name()) continue;
path.push_back(to_string(symbol->name()));
}

std::string name;
auto loc = gen.getLocation(ast->firstSourceLocation());

if (ast->symbol->hasCLinkage()) {
name = to_string(ast->symbol->name());
} else {
ExternalNameEncoder encoder;
name = encoder.encode(ast->symbol);
// Add the function body.
auto entryBlock = gen.builder_.createBlock(&func.getBody());
for (const auto& input : func.getFunctionType().getInputs()) {
entryBlock->addArgument(input, loc);
}

auto guard = mlir::OpBuilder::InsertionGuard(gen.builder_);

const auto loc = gen.getLocation(ast->symbol->location());

std::unordered_map<Symbol*, mlir::Value> locals;
auto func = gen.builder_.create<mlir::cxx::FuncOp>(loc, name, funcType);
auto entryBlock = &func.front();
auto exitBlock = gen.builder_.createBlock(&func.getBody());
mlir::cxx::AllocaOp exitValue;

Expand All @@ -370,6 +344,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
exitValue = gen.builder_.create<mlir::cxx::AllocaOp>(exitValueLoc, ptrType);
}

std::unordered_map<Symbol*, mlir::Value> locals;

// function state
std::swap(gen.function_, func);
std::swap(gen.exitBlock_, exitBlock);
Expand Down
37 changes: 37 additions & 0 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,43 @@ auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
-> ExpressionResult {
auto check_direct_call = [&]() -> ExpressionResult {
auto func = ast->baseExpression;

while (auto nested = ast_cast<NestedExpressionAST>(func)) {
func = nested->expression;
}

auto id = ast_cast<IdExpressionAST>(func);
if (!id) return {};

auto functionSymbol = symbol_cast<FunctionSymbol>(id->symbol);

if (!functionSymbol) return {};

auto funcOp = gen.findOrCreateFunction(functionSymbol);

mlir::SmallVector<mlir::Value> arguments;
for (auto node : ListView{ast->expressionList}) {
auto value = gen.expression(node);
arguments.push_back(value.value);
}

auto loc = gen.getLocation(ast->lparenLoc);

auto functionType = type_cast<FunctionType>(functionSymbol->type());
auto resultType = gen.convertType(functionType->returnType());
auto op = gen.builder_.create<mlir::cxx::CallOp>(
loc, resultType, funcOp.getSymName(), arguments, mlir::ArrayAttr{},
mlir::ArrayAttr{});

return {op};
};

if (auto op = check_direct_call(); op.value) {
return op;
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

Expand Down
5 changes: 0 additions & 5 deletions src/mlir/cxx/mlir/cxx_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,6 @@ void CxxDialect::initialize() {
addInterface<CxxGenerateAliases>();
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs) {
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
Expand Down
40 changes: 39 additions & 1 deletion src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
auto func = rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), op.getSymName(),
llvmFuncType);

if (op.getBody().empty()) {
func.setLinkage(LLVM::linkage::Linkage::External);
}

rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());

rewriter.eraseOp(op);
Expand All @@ -101,6 +105,39 @@ class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
}
};

class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();

SmallVector<Type> argumentTypes;
for (auto argType : op.getOperandTypes()) {
auto convertedType = typeConverter->convertType(argType);
if (!convertedType) {
return rewriter.notifyMatchFailure(
op, "failed to convert call argument type");
}
argumentTypes.push_back(convertedType);
}

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op,
"failed to convert call result types");
}

auto llvmCallOp = rewriter.create<LLVM::CallOp>(
op.getLoc(), resultType, adaptor.getCallee(), adaptor.getInputs());

rewriter.replaceOp(op, llvmCallOp.getResults());
return success();
}
};

class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -908,7 +945,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
RewritePatternSet patterns(context);

// function operations
patterns.insert<FuncOpLowering, ReturnOpLowering>(typeConverter, context);
patterns.insert<FuncOpLowering, ReturnOpLowering, CallOpLowering>(
typeConverter, context);

// memory operations
patterns.insert<AllocaOpLowering, LoadOpLowering, StoreOpLowering>(
Expand Down
Loading