Skip to content

Commit 86ec5a2

Browse files
committed
Lower floats and struct types to the LLVM dialect
1 parent 2bda655 commit 86ec5a2

File tree

5 files changed

+137
-17
lines changed

5 files changed

+137
-17
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include "mlir/IR/AttrTypeBase.td"
2323
include "mlir/IR/SymbolInterfaces.td"
2424
include "mlir/Interfaces/FunctionInterfaces.td"
2525
include "mlir/Interfaces/SideEffectInterfaces.td"
26+
include "mlir/IR/BuiltinAttributeInterfaces.td"
2627

2728
def Cxx_Dialect : Dialect {
2829
let name = "cxx";
@@ -148,23 +149,23 @@ def Cxx_StoreOp : Cxx_Op<"store"> {
148149
def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
149150
Pure
150151
]> {
151-
let arguments = (ins BoolAttr:$value);
152+
let arguments = (ins BoolProp:$value);
152153

153154
let results = (outs Cxx_BoolType:$result);
154155
}
155156

156157
def Cxx_IntConstantOp : Cxx_Op<"constant.int", [
157158
Pure
158159
]> {
159-
let arguments = (ins I64Attr:$value);
160+
let arguments = (ins I64Prop:$value);
160161

161162
let results = (outs Cxx_IntegerType:$result);
162163
}
163164

164165
def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
165166
Pure
166167
]> {
167-
let arguments = (ins F64Attr:$value);
168+
let arguments = (ins TypedAttrInterface:$value);
168169

169170
let results = (outs Cxx_FloatType:$result);
170171
}
@@ -174,7 +175,7 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
174175
//
175176

176177
def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
177-
let arguments = (ins StrAttr:$message);
178+
let arguments = (ins StringProp:$message);
178179
let results = (outs Cxx_ExprType:$result);
179180
let assemblyFormat = "$message attr-dict `:` type($result)";
180181
let builders =
@@ -185,7 +186,7 @@ def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
185186
}
186187

187188
def Cxx_TodoStmtOp : Cxx_Op<"todo.stmt"> {
188-
let arguments = (ins StrAttr:$message);
189+
let arguments = (ins StringProp:$message);
189190
let results = (outs);
190191
let assemblyFormat = "$message attr-dict";
191192
}

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
// cxx
2424
#include <cxx/ast.h>
2525
#include <cxx/literals.h>
26+
#include <cxx/types.h>
2627

2728
namespace cxx {
2829

@@ -133,8 +134,7 @@ auto Codegen::ExpressionVisitor::operator()(CharLiteralExpressionAST* ast)
133134
auto loc = gen.getLocation(ast->literalLoc);
134135

135136
auto type = gen.convertType(ast->type);
136-
auto value = gen.builder_.getI64IntegerAttr(ast->literal->charValue());
137-
137+
auto value = std::int64_t(ast->literal->charValue());
138138
auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);
139139

140140
return {op};
@@ -145,9 +145,9 @@ auto Codegen::ExpressionVisitor::operator()(BoolLiteralExpressionAST* ast)
145145
auto loc = gen.getLocation(ast->literalLoc);
146146

147147
auto type = gen.convertType(ast->type);
148-
auto value = gen.builder_.getBoolAttr(ast->isTrue);
149148

150-
auto op = gen.builder_.create<mlir::cxx::BoolConstantOp>(loc, type, value);
149+
auto op =
150+
gen.builder_.create<mlir::cxx::BoolConstantOp>(loc, type, ast->isTrue);
151151

152152
return {op};
153153
}
@@ -157,7 +157,7 @@ auto Codegen::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast)
157157
auto loc = gen.getLocation(ast->literalLoc);
158158

159159
auto type = gen.convertType(ast->type);
160-
auto value = gen.builder_.getI64IntegerAttr(ast->literal->integerValue());
160+
auto value = ast->literal->integerValue();
161161

162162
auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);
163163

@@ -169,7 +169,25 @@ auto Codegen::ExpressionVisitor::operator()(FloatLiteralExpressionAST* ast)
169169
auto loc = gen.getLocation(ast->literalLoc);
170170

171171
auto type = gen.convertType(ast->type);
172-
auto value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());
172+
173+
mlir::TypedAttr value;
174+
175+
switch (ast->type->kind()) {
176+
case TypeKind::kFloat:
177+
value = gen.builder_.getF32FloatAttr(ast->literal->floatValue());
178+
break;
179+
case TypeKind::kDouble:
180+
value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());
181+
break;
182+
case TypeKind::kLongDouble:
183+
value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());
184+
break;
185+
default:
186+
// Handle other float types if necessary
187+
auto op = gen.emitTodoExpr(ast->firstSourceLocation(),
188+
"unsupported float type");
189+
return {op};
190+
}
173191

174192
auto op = gen.builder_.create<mlir::cxx::FloatConstantOp>(loc, type, value);
175193

src/mlir/cxx/mlir/convert_type.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,14 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {
299299
}
300300

301301
auto Codegen::ConvertType::operator()(const EnumType* type) -> mlir::Type {
302-
return getExprType();
302+
if (type->underlyingType()) return gen.convertType(type->underlyingType());
303+
return gen.builder_.getType<mlir::cxx::IntegerType>(32, true);
303304
}
304305

305306
auto Codegen::ConvertType::operator()(const ScopedEnumType* type)
306307
-> mlir::Type {
307-
return getExprType();
308+
if (type->underlyingType()) return gen.convertType(type->underlyingType());
309+
return gen.builder_.getType<mlir::cxx::IntegerType>(32, true);
308310
}
309311

310312
auto Codegen::ConvertType::operator()(const MemberObjectPointerType* type)

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,29 @@ class StoreOpLowering : public OpConversionPattern<cxx::StoreOp> {
170170
}
171171
};
172172

173+
class BoolConstantOpLowering : public OpConversionPattern<cxx::BoolConstantOp> {
174+
public:
175+
using OpConversionPattern::OpConversionPattern;
176+
177+
auto matchAndRewrite(cxx::BoolConstantOp op, OpAdaptor adaptor,
178+
ConversionPatternRewriter &rewriter) const
179+
-> LogicalResult override {
180+
auto typeConverter = getTypeConverter();
181+
auto context = getContext();
182+
183+
auto resultType = typeConverter->convertType(op.getType());
184+
if (!resultType) {
185+
return rewriter.notifyMatchFailure(
186+
op, "failed to convert boolean constant type");
187+
}
188+
189+
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
190+
adaptor.getValue());
191+
192+
return success();
193+
}
194+
};
195+
173196
class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
174197
public:
175198
using OpConversionPattern::OpConversionPattern;
@@ -186,8 +209,32 @@ class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
186209
op, "failed to convert integer constant type");
187210
}
188211

189-
auto valueAttr = adaptor.getValueAttr();
190-
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType, valueAttr);
212+
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
213+
adaptor.getValue());
214+
215+
return success();
216+
}
217+
};
218+
219+
class FloatConstantOpLowering
220+
: public OpConversionPattern<cxx::FloatConstantOp> {
221+
public:
222+
using OpConversionPattern::OpConversionPattern;
223+
224+
auto matchAndRewrite(cxx::FloatConstantOp op, OpAdaptor adaptor,
225+
ConversionPatternRewriter &rewriter) const
226+
-> LogicalResult override {
227+
auto typeConverter = getTypeConverter();
228+
auto context = getContext();
229+
230+
auto resultType = typeConverter->convertType(op.getType());
231+
if (!resultType) {
232+
return rewriter.notifyMatchFailure(
233+
op, "failed to convert float constant type");
234+
}
235+
236+
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
237+
adaptor.getValue());
191238

192239
return success();
193240
}
@@ -218,14 +265,60 @@ void CxxToLLVMLoweringPass::runOnOperation() {
218265

219266
// set up the type converter
220267
LLVMTypeConverter typeConverter{context};
268+
269+
typeConverter.addConversion([](cxx::BoolType type) {
270+
// todo: i8/i32 for data and i1 for control flow
271+
return IntegerType::get(type.getContext(), 8);
272+
});
273+
221274
typeConverter.addConversion([](cxx::IntegerType type) {
222275
return IntegerType::get(type.getContext(), type.getWidth());
223276
});
224277

278+
typeConverter.addConversion([](cxx::FloatType type) -> Type {
279+
auto width = type.getWidth();
280+
switch (width) {
281+
case 16:
282+
return Float16Type::get(type.getContext());
283+
case 32:
284+
return Float32Type::get(type.getContext());
285+
case 64:
286+
return Float64Type::get(type.getContext());
287+
default:
288+
return {};
289+
} // switch
290+
});
291+
225292
typeConverter.addConversion([](cxx::PointerType type) {
226293
return LLVM::LLVMPointerType::get(type.getContext());
227294
});
228295

296+
DenseMap<cxx::ClassType, Type> convertedClassTypes;
297+
typeConverter.addConversion([&](cxx::ClassType type) -> Type {
298+
if (auto it = convertedClassTypes.find(type);
299+
it != convertedClassTypes.end()) {
300+
return it->second;
301+
}
302+
303+
auto structType =
304+
LLVM::LLVMStructType::getIdentified(type.getContext(), type.getName());
305+
306+
convertedClassTypes[type] = structType;
307+
308+
SmallVector<Type> fieldTypes;
309+
bool isPacked = false;
310+
311+
for (auto field : type.getBody()) {
312+
auto convertedFieldType = typeConverter.convertType(field);
313+
// todo: check if the field type was converted successfully
314+
fieldTypes.push_back(convertedFieldType);
315+
}
316+
317+
structType.setBody(fieldTypes, isPacked);
318+
319+
return structType;
320+
});
321+
229322
// set up the conversion patterns
230323
ConversionTarget target(*context);
231324

@@ -234,8 +327,9 @@ void CxxToLLVMLoweringPass::runOnOperation() {
234327

235328
RewritePatternSet patterns(context);
236329
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
237-
LoadOpLowering, StoreOpLowering, IntConstantOpLowering>(
238-
typeConverter, context);
330+
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
331+
IntConstantOpLowering, FloatConstantOpLowering>(typeConverter,
332+
context);
239333

240334
populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
241335
typeConverter);

src/parser/cxx/decl.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ struct GetDeclaratorType {
215215
if (auto params = ast->parameterDeclarationClause) {
216216
for (auto it = params->parameterDeclarationList; it; it = it->next) {
217217
auto paramType = it->value->type;
218+
219+
if (control()->is_void(paramType)) {
220+
continue;
221+
}
222+
218223
parameterTypes.push_back(paramType);
219224
}
220225

0 commit comments

Comments
 (0)