Skip to content

Commit 1f8befe

Browse files
committed
[Heavy] Fix access to global bindings; Expose current-builder
1 parent a66d7c4 commit 1f8befe

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

heavy/include/heavy/Value.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,7 @@ struct ContextLocal {
20432043
heavy::Value init(heavy::Context& C, heavy::Value Value = nullptr);
20442044
uintptr_t key() const { return reinterpret_cast<uintptr_t>(this); }
20452045
heavy::Value get(heavy::ContextLocalLookup const& C) const;
2046+
heavy::Value get_binding(heavy::ContextLocalLookup const& C) const;
20462047
void set(heavy::ContextLocalLookup& C, heavy::Value Value);
20472048
};
20482049

heavy/lib/Context.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,12 @@ heavy::Value ContextLocal::get(heavy::ContextLocalLookup const& C) const {
14471447
return Value;
14481448
}
14491449

1450+
heavy::Value ContextLocal::get_binding(
1451+
heavy::ContextLocalLookup const& C) const {
1452+
heavy::Value Value = C.LookupTable.lookup(key());
1453+
return dyn_cast<heavy::Binding>(Value);
1454+
}
1455+
14501456

14511457
// Module
14521458

heavy/lib/Mlir.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,9 @@ void block_arg(Context& C, ValueRefs Args) {
424424

425425
static void with_builder_impl(Context& C, mlir::OpBuilder const& Builder,
426426
heavy::Value Thunk) {
427-
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
428427
heavy::Value PrevBuilder = C.CreateBinding(heavy::Empty());
429428
heavy::Value NewBuilder = CreateTagged(C, kind::mlir_builder,
430-
mlir::OpBuilder(MLIRContext));
429+
Builder);
431430

432431
heavy::Value Before = C.CreateLambda(
433432
[](heavy::Context& C, heavy::ValueRefs Args) {
@@ -712,6 +711,8 @@ void load_dialect(Context& C, heavy::ValueRefs Args) {
712711
return C.RaiseError("expecting dialect name");
713712

714713
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
714+
// Ensure the registry is up to date.
715+
MLIRContext->appendDialectRegistry(*C.DialectRegistry);
715716
mlir::Dialect* Dialect = MLIRContext->getOrLoadDialect(Name);
716717
if (Dialect == nullptr)
717718
return C.RaiseError(C.CreateString("failed to load dialect: ", Name), {});
@@ -817,6 +818,7 @@ void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
817818
heavy::initModuleNames(C, HEAVY_MLIR_LIB_STR, {
818819
{"create-op", HEAVY_MLIR_VAR(create_op)},
819820
{"%create-op", HEAVY_MLIR_VAR(create_op_impl)},
821+
{"current-builder", HEAVY_MLIR_VAR(current_builder).get_binding(C)},
820822
{"region", HEAVY_MLIR_VAR(region)},
821823
{"entry-block", HEAVY_MLIR_VAR(entry_block)},
822824
{"results", HEAVY_MLIR_VAR(results)},

0 commit comments

Comments
 (0)