diff --git a/lib/vast/Conversion/Parser/Refine.cpp b/lib/vast/Conversion/Parser/Refine.cpp index fae07eb16e..4b4faf63c5 100644 --- a/lib/vast/Conversion/Parser/Refine.cpp +++ b/lib/vast/Conversion/Parser/Refine.cpp @@ -107,6 +107,64 @@ namespace vast::conv { } }; + struct RefineDeclType : operation_conversion_pattern< pr::Decl > + { + using op_t = pr::Decl; + using base = operation_conversion_pattern< op_t >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + + static mlir_type resolve(mlir_type a, mlir_type b) { + if (!a) { + return b; + } + return a == b ? a : pr::MaybeDataType::get(a.getContext()); + } + + static mlir_type assigned_type(mlir_value value) { + if (!value.getDefiningOp()) { + return value.getType(); + } else if (auto cast = dyn_cast< pr::Cast >(value.getDefiningOp())) { + return assigned_type(cast.getOperand()); + } else { + return value.getType(); + } + } + + static mlir_type assigned_type(op_t op) { + mlir_type type = {}; + auto mod = op->template getParentOfType< core::ModuleOp >(); + for (auto use : core::symbol_table::get_symbol_uses(op, mod)) { + if (auto ref = dyn_cast< pr::Ref >(use.getUser())) { + for (auto ref_user : ref->getUsers()) { + if (auto assign = dyn_cast< pr::Assign >(ref_user)) { + type = resolve(type, assigned_type(assign.getValue())); + } else if (isa< mlir::CallOpInterface >(ref_user)) { + return op.getType(); + } + } + } + } + return op.getType(); + } + + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + op->dump(); + rewriter.modifyOpInPlace(op, [&] { op.setType(assigned_type(op)); }); + return mlir::success(); + } + + static void legalize(base_conversion_config &cfg) { + cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) { + return !pr::is_maybedata(op.getType()) + || pr::is_maybedata(assigned_type(op)); + }); + } + }; + template< typename op_t > struct DeadOpElimination : operation_conversion_pattern< op_t > { @@ -144,6 +202,7 @@ namespace vast::conv { NoParseFold< hl::ChooseExprOp >, NoParseFold< hl::BinaryCondOp >, RefineReturn, + RefineDeclType, DefinitionElimination< hl::EnumDeclOp >, DefinitionElimination< hl::StructDeclOp >, DefinitionElimination< hl::UnionDeclOp >,