Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions lib/vast/Conversion/Parser/Refine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,64 @@ namespace vast::conv {
}
};

struct RefineDeclType : operation_conversion_pattern< pr::Decl >
{
using op_t = pr::Decl;
using base = operation_conversion_pattern< op_t >;
using base::base;

using adaptor_t = typename op_t::Adaptor;

static mlir_type resolve(mlir_type a, mlir_type b) {
if (!a) {
return b;
}
return a == b ? a : pr::MaybeDataType::get(a.getContext());
}

static mlir_type assigned_type(mlir_value value) {
if (!value.getDefiningOp()) {
return value.getType();
} else if (auto cast = dyn_cast< pr::Cast >(value.getDefiningOp())) {
return assigned_type(cast.getOperand());
} else {
return value.getType();
}
}

static mlir_type assigned_type(op_t op) {
mlir_type type = {};
auto mod = op->template getParentOfType< core::ModuleOp >();
for (auto use : core::symbol_table::get_symbol_uses(op, mod)) {
if (auto ref = dyn_cast< pr::Ref >(use.getUser())) {
for (auto ref_user : ref->getUsers()) {
if (auto assign = dyn_cast< pr::Assign >(ref_user)) {
type = resolve(type, assigned_type(assign.getValue()));
} else if (isa< mlir::CallOpInterface >(ref_user)) {
return op.getType();
}
}
}
}
return op.getType();
}

logical_result matchAndRewrite(
op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
) const override {
op->dump();
rewriter.modifyOpInPlace(op, [&] { op.setType(assigned_type(op)); });
return mlir::success();
}

static void legalize(base_conversion_config &cfg) {
cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) {
return !pr::is_maybedata(op.getType())
|| pr::is_maybedata(assigned_type(op));
});
}
};

template< typename op_t >
struct DeadOpElimination : operation_conversion_pattern< op_t >
{
Expand Down Expand Up @@ -144,6 +202,7 @@ namespace vast::conv {
NoParseFold< hl::ChooseExprOp >,
NoParseFold< hl::BinaryCondOp >,
RefineReturn,
RefineDeclType,
DefinitionElimination< hl::EnumDeclOp >,
DefinitionElimination< hl::StructDeclOp >,
DefinitionElimination< hl::UnionDeclOp >,
Expand Down