Commit eacda0d

benvanik and claude authored
Add iree_hal_device_group_t to own device topology lifecycle (iree-org#23576)
The topology matrix from iree-org#23573 needs an owner with a clear lifetime contract. Today devices are passed to the HAL module as a flat array: there is no object that retains the devices, owns the topology, and guarantees the topology pointer remains valid for the duration of execution. Every device holds a raw pointer to the topology it was assigned, so if the topology is freed while devices are still alive, those pointers dangle.

`iree_hal_device_group_t` is that owner. It takes already-created devices, builds the immutable topology matrix from their capabilities, pushes topology info into each device, and retains all of them. The group's lifetime brackets the devices': whoever holds the devices long-term (the HAL module, the CTS harness, a Python session) retains the group, and the group retains the devices, so the topology pointer in each device is guaranteed valid.

### Creation API

The builder pattern matches the topology builder it wraps: stack-allocate, add devices, finalize. Finalize is a consuming operation: it queries capabilities from all devices, computes edge descriptors, calls driver-specific refinement for same-driver pairs, builds the topology matrix, assigns topology info into each device via the new vtable method, and produces the immutable group. The builder is zeroed after finalize (whether it succeeds or fails) and cannot be reused.

For the common single-device case (7 of 9 callers), `iree_hal_device_group_create_from_device` wraps the builder sequence into a one-liner.

### `assign_topology_info` vtable method

Devices need to receive their topology info after the matrix is built: the topology doesn't exist yet when the device is created, and the device's index in the matrix isn't known until group creation. This is a new vtable method on `iree_hal_device_t` that the group calls during finalize. All existing driver implementations (local-sync, local-task, CUDA, HIP, Vulkan, Metal, AMDGPU, null) store the info into their device struct. The method is called exactly once per device, during group creation.

### HAL module integration

`iree_hal_module_create` now takes a `iree_hal_device_group_t*` instead of `(device_count, devices[])`. The module retains the group and delegates all device access through `iree_hal_device_group_device_count` / `iree_hal_device_group_device_at`. This eliminates the flexible array member from `iree_hal_module_t`, simplifies allocation (fixed-size struct instead of variable-size), and makes the lifetime contract explicit: the module holds the group, the group holds the devices and topology.

All callers are updated: CLI tooling, the high-level runtime session, TFLite bindings, Python bindings, PJRT, ConstEval, simple_embedding samples, check_test, and the CTS test harness.

### Testing

A mock device (`hal/testing/mock_device`) provides controllable capabilities for testing topology construction without requiring real hardware. The device group tests exercise builder validation (empty builds, duplicate devices, capacity limits), single-device and multi-device group creation, topology correctness (self-edges, cross-device edges with expected interop modes), the convenience function, and lifetime ordering (group outlives devices).

The CTS test harness creates a device group in `SetUpTestSuite` so every CTS test runs with topology info assigned.

### Where this is going

The device group is the scheduling domain for the causal execution system. When the AMDGPU driver gets its frontier-integrated semaphores and queue operations, the device group's topology matrix is what tells the scheduler whether a semaphore can be waited on natively or needs handle import, whether a buffer can be read directly or needs DMA transfer, and what the relative cost is. The group also becomes the natural attachment point for collective channel creation and multi-device resource pools.

---------

Co-authored-by: Claude <noreply@anthropic.com>
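The builder and lifetime contract described in the message can be illustrated with a small standalone model. This is a minimal sketch, not IREE code: every `sketch_*` name, the `SKETCH_MAX_DEVICES` capacity, the integer status codes, and the struct layouts are hypothetical stand-ins. It shows only three properties the message states: finalize consumes (zeroes) the builder, the group retains each device, and each device ends up with a raw pointer into topology storage embedded in the group.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#define SKETCH_MAX_DEVICES 4 /* hypothetical capacity limit */

/* Stand-in for the immutable topology matrix. */
typedef struct sketch_topology_t {
  size_t device_count;
} sketch_topology_t;

typedef struct sketch_device_t {
  int ref_count;
  const sketch_topology_t* topology; /* raw, non-owning pointer */
  size_t topology_index;
} sketch_device_t;

typedef struct sketch_group_t {
  int ref_count;
  sketch_topology_t topology; /* embedded, so it lives as long as the group */
  sketch_device_t* devices[SKETCH_MAX_DEVICES];
} sketch_group_t;

typedef struct sketch_builder_t {
  size_t count;
  sketch_device_t* devices[SKETCH_MAX_DEVICES];
} sketch_builder_t;

sketch_device_t* sketch_device_create(void) {
  sketch_device_t* device = calloc(1, sizeof(*device));
  if (device) device->ref_count = 1;
  return device;
}

void sketch_device_release(sketch_device_t* device) {
  if (device && --device->ref_count == 0) free(device);
}

void sketch_builder_initialize(sketch_builder_t* builder) {
  memset(builder, 0, sizeof(*builder));
}

/* Returns 0 on success, -1 if the builder is full or the device is a
 * duplicate. */
int sketch_builder_add_device(sketch_builder_t* builder,
                              sketch_device_t* device) {
  if (builder->count >= SKETCH_MAX_DEVICES) return -1;
  for (size_t i = 0; i < builder->count; ++i) {
    if (builder->devices[i] == device) return -1;
  }
  builder->devices[builder->count++] = device;
  return 0;
}

/* Consumes the builder: it is zeroed whether or not a group is produced. */
int sketch_builder_finalize(sketch_builder_t* builder,
                            sketch_group_t** out_group) {
  sketch_group_t* group =
      builder->count > 0 ? calloc(1, sizeof(*group)) : NULL;
  if (group) {
    group->ref_count = 1;
    group->topology.device_count = builder->count;
    for (size_t i = 0; i < builder->count; ++i) {
      sketch_device_t* device = builder->devices[i];
      ++device->ref_count;                 /* the group retains the device */
      device->topology = &group->topology; /* assign topology info */
      device->topology_index = i;
      group->devices[i] = device;
    }
  }
  memset(builder, 0, sizeof(*builder));
  *out_group = group;
  return group ? 0 : -1;
}

/* Releasing the last group reference drops the group's device references. */
void sketch_group_release(sketch_group_t* group) {
  if (!group || --group->ref_count > 0) return;
  for (size_t i = 0; i < group->topology.device_count; ++i) {
    sketch_device_release(group->devices[i]);
  }
  free(group);
}
```

In this model a caller can drop its own device reference right after finalize and keep only the group. Releasing the group while a device is still referenced elsewhere would leave that device's `topology` pointer dangling, which is the hazard the ordering comments in the CTS harness guard against.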
1 parent a2c8b6b commit eacda0d

33 files changed, +1646 -122 lines changed

compiler/src/iree/compiler/ConstEval/Runtime.cpp

Lines changed: 18 additions & 13 deletions
@@ -126,8 +126,8 @@ CompiledBinary::CompiledBinary() = default;
 CompiledBinary::~CompiledBinary() = default;

 void CompiledBinary::deinitialize() {
-  hal_module.reset();
-  main_module.reset();
+  halModule.reset();
+  mainModule.reset();
   context.reset();
   device.reset();
 }
@@ -325,7 +325,7 @@ LogicalResult FunctionCall::invoke(Location loc, StringRef name) {
   // Lookup function.
   iree_vm_function_t function;
   if (auto status = iree_vm_module_lookup_function_by_name(
-          binary.main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
+          binary.mainModule.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
           iree_string_view_t{name.data(),
                              static_cast<iree_host_size_t>(name.size())},
           &function)) {
@@ -462,29 +462,34 @@ LogicalResult CompiledBinary::initialize(Location loc, void *data,
   }
   iree_hal_driver_release(driver);

-  // Create hal module.
+  // Create device group and hal module.
+  iree_hal_device_group_t *deviceGroup = nullptr;
   if (iree_status_is_ok(status)) {
-    std::array<iree_hal_device_t *, 1> devices = {device.get()};
-    status = iree_hal_module_create(
-        runtime.instance.get(), iree_hal_module_device_policy_default(),
-        devices.size(), devices.data(), IREE_HAL_MODULE_FLAG_NONE,
-        iree_hal_module_debug_sink_stdio(stderr), iree_allocator_system(),
-        &hal_module);
+    status = iree_hal_device_group_create_from_device(
+        device.get(), iree_allocator_system(), &deviceGroup);
   }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_module_create(runtime.instance.get(),
+                                    iree_hal_module_device_policy_default(),
+                                    deviceGroup, IREE_HAL_MODULE_FLAG_NONE,
+                                    iree_hal_module_debug_sink_stdio(stderr),
+                                    iree_allocator_system(), &halModule);
+  }
+  iree_hal_device_group_release(deviceGroup);

   // Bytecode module.
   if (iree_status_is_ok(status)) {
     status = iree_vm_bytecode_module_create(
         runtime.instance.get(), IREE_VM_BYTECODE_MODULE_FLAG_NONE,
         iree_make_const_byte_span(data, length), iree_allocator_null(),
-        iree_allocator_system(), &main_module);
+        iree_allocator_system(), &mainModule);
   }

   // Create context.
   if (iree_status_is_ok(status)) {
     std::array<iree_vm_module_t *, 2> modules = {
-        hal_module.get(),
-        main_module.get(),
+        halModule.get(),
+        mainModule.get(),
     };
     status = iree_vm_context_create_with_modules(
         runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),

compiler/src/iree/compiler/ConstEval/Runtime.h

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ class CompiledBinary {
                         Type mlirType);

   iree::vm::ref<iree_hal_device_t> device;
-  iree::vm::ref<iree_vm_module_t> hal_module;
-  iree::vm::ref<iree_vm_module_t> main_module;
+  iree::vm::ref<iree_vm_module_t> halModule;
+  iree::vm::ref<iree_vm_module_t> mainModule;
   iree::vm::ref<iree_vm_context_t> context;

   friend class FunctionCall;

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

Lines changed: 9 additions & 5 deletions
@@ -1680,12 +1680,16 @@ iree_status_t ClientInstance::PopulateVMModules(
     iree_hal_device_t* hal_device,
     iree::vm::ref<iree_vm_module_t>& main_module) {
   // HAL module.
+  iree_hal_device_group_t* device_group = nullptr;
+  IREE_RETURN_IF_ERROR(iree_hal_device_group_create_from_device(
+      hal_device, host_allocator(), &device_group));
   modules.push_back({});
-  IREE_RETURN_IF_ERROR(iree_hal_module_create(
-      vm_instance(), iree_hal_module_device_policy_default(),
-      /*device_count=*/1, &hal_device, IREE_HAL_MODULE_FLAG_NONE,
-      iree_hal_module_debug_sink_stdio(stderr), host_allocator(),
-      &modules.back()));
+  iree_status_t status = iree_hal_module_create(
+      vm_instance(), iree_hal_module_device_policy_default(), device_group,
+      IREE_HAL_MODULE_FLAG_NONE, iree_hal_module_debug_sink_stdio(stderr),
+      host_allocator(), &modules.back());
+  iree_hal_device_group_release(device_group);
+  IREE_RETURN_IF_ERROR(status);

   // Main module.
   modules.push_back(main_module);

runtime/bindings/python/hal.cc

Lines changed: 21 additions & 19 deletions
@@ -1107,38 +1107,40 @@ VmModule CreateHalModule(
         PyExc_ValueError,
         "\"device\" and \"devices\" are mutually exclusive arguments.");
   }
-  std::vector<iree_hal_device_t*> devices_vector;
-  iree_hal_device_t* device_ptr;
-  iree_hal_device_t** devices_ptr;
-  iree_host_size_t device_count;
-  iree_vm_module_t* module = NULL;
+  // Build a device group from the provided device(s).
+  iree_hal_device_group_builder_t group_builder;
+  iree_hal_device_group_builder_initialize(&group_builder);
   if (device) {
-    device_ptr = device.value()->raw_ptr();
-    devices_ptr = &device_ptr;
-    device_count = 1;
+    CheckApiStatus(iree_hal_device_group_builder_add_device(
+                       &group_builder, device.value()->raw_ptr()),
+                   "Error adding device to group builder");
   } else {
-    // Set device related arguments in the case of multiple devices.
-    devices_vector.reserve(devices->size());
     for (auto devicesIt = devices->begin(); devicesIt != devices->end();
          ++devicesIt) {
-      devices_vector.push_back(py::cast<HalDevice*>(*devicesIt)->raw_ptr());
+      CheckApiStatus(
+          iree_hal_device_group_builder_add_device(
+              &group_builder, py::cast<HalDevice*>(*devicesIt)->raw_ptr()),
+          "Error adding device to group builder");
     }
-    devices_ptr = devices_vector.data();
-    device_count = devices_vector.size();
   }
+  iree_hal_device_group_t* device_group = nullptr;
+  CheckApiStatus(iree_hal_device_group_builder_finalize(
+                     &group_builder, iree_allocator_system(), &device_group),
+                 "Error creating device group");

   iree_hal_module_debug_sink_t iree_hal_module_debug_sink =
       iree_hal_module_debug_sink_stdio(stderr);
   if (debug_sink) {
     iree_hal_module_debug_sink = (*debug_sink)->AsIreeHalModuleDebugSink();
   }

-  CheckApiStatus(
-      iree_hal_module_create(
-          instance->raw_ptr(), iree_hal_module_device_policy_default(),
-          device_count, devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
-          iree_hal_module_debug_sink, iree_allocator_system(), &module),
-      "Error creating hal module");
+  iree_vm_module_t* module = NULL;
+  iree_status_t status = iree_hal_module_create(
+      instance->raw_ptr(), iree_hal_module_device_policy_default(),
+      device_group, IREE_HAL_MODULE_FLAG_NONE, iree_hal_module_debug_sink,
+      iree_allocator_system(), &module);
+  iree_hal_device_group_release(device_group);
+  CheckApiStatus(status, "Error creating hal module");
   VmModule vm_module = VmModule::StealFromRawPtr(module);
   if (debug_sink) {
     // Retain a reference. We want the callback to be valid after

runtime/bindings/tflite/interpreter.c

Lines changed: 8 additions & 3 deletions
@@ -60,11 +60,16 @@ static iree_status_t _TfLiteInterpreterPrepareHAL(
         "failed creating the default device for driver '%.*s'",
         (int)driver_name.size, driver_name.data);

-  IREE_RETURN_IF_ERROR(iree_hal_module_create(
+  iree_hal_device_group_t* device_group = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_group_create_from_device(
+      interpreter->device, interpreter->allocator, &device_group));
+  status = iree_hal_module_create(
       interpreter->instance, iree_hal_module_device_policy_default(),
-      /*device_count=*/1, &interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
+      device_group, IREE_HAL_MODULE_FLAG_NONE,
       iree_hal_module_debug_sink_stdio(stderr), interpreter->allocator,
-      &interpreter->hal_module));
+      &interpreter->hal_module);
+  iree_hal_device_group_release(device_group);
+  IREE_RETURN_IF_ERROR(status);

   return iree_ok_status();
 }
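The TFLite, PJRT, and Python call sites in this commit share one shape: create the group, capture the status of module creation instead of returning early, release the local group reference unconditionally, and only then propagate the status. A minimal sketch of why that ordering avoids a leak follows. It is a toy reference-counting model, assuming the module takes its own reference on success as the commit message describes; all `sketch_*` names, the `simulate_failure` flag, and the live-group counter are hypothetical and exist only for illustration.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

typedef struct sketch_group_t {
  int ref_count;
} sketch_group_t;

typedef struct sketch_module_t {
  sketch_group_t* group; /* retained for the module's lifetime */
} sketch_module_t;

int sketch_live_groups = 0; /* leak counter used only by this sketch */

sketch_group_t* sketch_group_create(void) {
  sketch_group_t* group = calloc(1, sizeof(*group));
  if (group) {
    group->ref_count = 1;
    ++sketch_live_groups;
  }
  return group;
}

/* NULL-safe, so it can be called unconditionally on cleanup paths. */
void sketch_group_release(sketch_group_t* group) {
  if (!group || --group->ref_count > 0) return;
  --sketch_live_groups;
  free(group);
}

/* On success the module takes its own reference to the group. */
int sketch_module_create(sketch_group_t* group, bool simulate_failure,
                         sketch_module_t** out_module) {
  *out_module = NULL;
  if (simulate_failure) return -1;
  sketch_module_t* module = calloc(1, sizeof(*module));
  if (!module) return -1;
  ++group->ref_count;
  module->group = group;
  *out_module = module;
  return 0;
}

void sketch_module_destroy(sketch_module_t* module) {
  if (!module) return;
  sketch_group_release(module->group);
  free(module);
}

/* Mirrors the call-site shape: capture status, release, then propagate. */
int sketch_prepare_hal(bool simulate_failure, sketch_module_t** out_module) {
  *out_module = NULL;
  sketch_group_t* group = sketch_group_create();
  if (!group) return -1;
  int status = sketch_module_create(group, simulate_failure, out_module);
  sketch_group_release(group); /* drop the local reference on both paths */
  return status;               /* only now propagate any error */
}
```

Had the sketch returned immediately when `sketch_module_create` failed, the local reference would never be dropped and the group would leak. On the success path the module's own reference keeps the group alive after the local release.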

runtime/src/iree/hal/BUILD.bazel

Lines changed: 14 additions & 0 deletions
@@ -47,6 +47,8 @@ iree_runtime_cc_library(
         "detail.h",
         "device.c",
         "device.h",
+        "device_group.c",
+        "device_group.h",
         "driver.c",
         "driver.h",
         "driver_registry.c",
@@ -118,3 +120,15 @@ iree_runtime_cc_test(
         "//runtime/src/iree/testing:gtest_main",
     ],
 )
+
+iree_runtime_cc_test(
+    name = "device_group_test",
+    srcs = ["device_group_test.cc"],
+    deps = [
+        ":hal",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/hal/testing:mock_device",
+        "//runtime/src/iree/testing:gtest",
+        "//runtime/src/iree/testing:gtest_main",
+    ],
+)

runtime/src/iree/hal/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -40,6 +40,8 @@ iree_cc_library(
     "detail.h"
     "device.c"
     "device.h"
+    "device_group.c"
+    "device_group.h"
     "driver.c"
     "driver.h"
     "driver_registry.c"
@@ -111,4 +113,17 @@ iree_cc_test(
     iree::testing::gtest_main
 )

+iree_cc_test(
+  NAME
+    device_group_test
+  SRCS
+    "device_group_test.cc"
+  DEPS
+    ::hal
+    iree::base
+    iree::hal::testing::mock_device
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###

runtime/src/iree/hal/api.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 #include "iree/hal/channel_provider.h" // IWYU pragma: export
 #include "iree/hal/command_buffer.h" // IWYU pragma: export
 #include "iree/hal/device.h" // IWYU pragma: export
+#include "iree/hal/device_group.h" // IWYU pragma: export
 #include "iree/hal/driver.h" // IWYU pragma: export
 #include "iree/hal/driver_registry.h" // IWYU pragma: export
 #include "iree/hal/event.h" // IWYU pragma: export

runtime/src/iree/hal/cts/cts_test_base.h

Lines changed: 15 additions & 0 deletions
@@ -89,10 +89,12 @@ static iree_status_t TryGetDriver(const std::string& driver_name,
 class CTSTestResources {
  public:
   static iree_hal_driver_t* driver_;
+  static iree_hal_device_group_t* device_group_;
   static iree_hal_device_t* device_;
   static iree_hal_allocator_t* device_allocator_;
 };
 /*static*/ iree_hal_driver_t* CTSTestResources::driver_ = NULL;
+/*static*/ iree_hal_device_group_t* CTSTestResources::device_group_ = NULL;
 /*static*/ iree_hal_device_t* CTSTestResources::device_ = NULL;
 /*static*/ iree_hal_allocator_t* CTSTestResources::device_allocator_ = NULL;

@@ -129,6 +131,13 @@ class CTSTestBase : public BaseType, public CTSTestResources {
     IREE_CHECK_OK(status);
     device_ = device;

+    // Create a device group so the device gets topology info assigned.
+    // The group must outlive the device (it holds a raw topology pointer).
+    iree_hal_device_group_t* device_group = NULL;
+    IREE_CHECK_OK(iree_hal_device_group_create_from_device(
+        device_, iree_allocator_system(), &device_group));
+    device_group_ = device_group;
+
     device_allocator_ = iree_hal_device_allocator(device_);
     iree_hal_allocator_retain(device_allocator_);
   }
@@ -142,6 +151,12 @@ class CTSTestBase : public BaseType, public CTSTestResources {
       iree_hal_device_release(device_);
       device_ = NULL;
     }
+    // Release the device group after the device — the device holds a raw
+    // pointer to the group's embedded topology.
+    if (device_group_) {
+      iree_hal_device_group_release(device_group_);
+      device_group_ = NULL;
+    }
     if (driver_) {
       iree_hal_driver_release(driver_);
       driver_ = NULL;

runtime/src/iree/hal/device.c

Lines changed: 8 additions & 0 deletions
@@ -96,6 +96,14 @@ IREE_API_EXPORT iree_status_t iree_hal_device_refine_topology_edge(
       dst_device, edge);
 }

+IREE_API_EXPORT iree_status_t iree_hal_device_assign_topology_info(
+    iree_hal_device_t* device,
+    const iree_hal_device_topology_info_t* topology_info) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(topology_info);
+  return _VTABLE_DISPATCH(device, assign_topology_info)(device, topology_info);
+}
+
 IREE_API_EXPORT iree_hal_semaphore_compatibility_t
 iree_hal_device_query_semaphore_compatibility(iree_hal_device_t* device,
                                               iree_hal_semaphore_t* semaphore) {
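
The new `iree_hal_device_assign_topology_info` entry point validates its arguments and forwards to the driver through the device vtable. The general shape of that pattern is sketched below in plain C. This is an illustration under stated assumptions, not the IREE implementation: the `sketch_*` types, the fields of `sketch_topology_info_t`, and the "null" driver are hypothetical, and the real `_VTABLE_DISPATCH` macro and `iree_hal_device_topology_info_t` layout are not shown in this diff.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical payload: a pointer to the shared matrix plus this device's
 * index within it. */
typedef struct sketch_topology_info_t {
  const void* topology;
  size_t index;
} sketch_topology_info_t;

typedef struct sketch_device_t sketch_device_t;

/* Per-driver function table; each driver supplies its own implementation. */
typedef struct sketch_device_vtable_t {
  int (*assign_topology_info)(sketch_device_t* device,
                              const sketch_topology_info_t* info);
} sketch_device_vtable_t;

struct sketch_device_t {
  const sketch_device_vtable_t* vtable;
};

/* A driver-specific device. The base struct is the first member so a base
 * pointer can be cast back to the derived type. */
typedef struct sketch_null_device_t {
  sketch_device_t base;
  sketch_topology_info_t topology_info;
} sketch_null_device_t;

/* Driver implementation: store the info into the device struct. */
int sketch_null_device_assign_topology_info(
    sketch_device_t* base, const sketch_topology_info_t* info) {
  sketch_null_device_t* device = (sketch_null_device_t*)base;
  device->topology_info = *info;
  return 0;
}

const sketch_device_vtable_t sketch_null_device_vtable = {
    .assign_topology_info = sketch_null_device_assign_topology_info,
};

/* Public entry point: check arguments, then dispatch through the vtable. */
int sketch_device_assign_topology_info(sketch_device_t* device,
                                       const sketch_topology_info_t* info) {
  assert(device);
  assert(info);
  return device->vtable->assign_topology_info(device, info);
}
```

The caller (the group, in the commit's design) only ever sees the base type, so adding a driver means adding a vtable entry rather than changing the group code.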
