@@ -52,7 +52,30 @@ class ur_legacy_sink : public logger::Sink {
5252 };
5353};
5454
55- ur_result_t initPlatforms (PlatformVec &platforms) noexcept try {
55+ // Find the corresponding ZesDevice Handle for a given ZeDevice
56+ ur_result_t getZesDeviceHandle (zes_uuid_t coreDeviceUuid,
57+ zes_device_handle_t *ZesDevice,
58+ uint32_t *SubDeviceId, ze_bool_t *SubDevice) {
59+ uint32_t ZesDriverCount = 0 ;
60+ std::vector<zes_driver_handle_t > ZesDrivers;
61+ std::vector<zes_device_handle_t > ZesDevices;
62+ ze_result_t ZesResult = ZE_RESULT_ERROR_INVALID_ARGUMENT;
63+ ZE2UR_CALL (zesDriverGet, (&ZesDriverCount, nullptr ));
64+ ZesDrivers.resize (ZesDriverCount);
65+ ZE2UR_CALL (zesDriverGet, (&ZesDriverCount, ZesDrivers.data ()));
66+ for (uint32_t I = 0 ; I < ZesDriverCount; ++I) {
67+ ZesResult = ZE_CALL_NOCHECK (
68+ zesDriverGetDeviceByUuidExp,
69+ (ZesDrivers[I], coreDeviceUuid, ZesDevice, SubDevice, SubDeviceId));
70+ if (ZesResult == ZE_RESULT_SUCCESS) {
71+ return UR_RESULT_SUCCESS;
72+ }
73+ }
74+ return UR_RESULT_ERROR_INVALID_ARGUMENT;
75+ }
76+
77+ ur_result_t initPlatforms (PlatformVec &platforms,
78+ ze_result_t ZesResult) noexcept try {
5679 uint32_t ZeDriverCount = 0 ;
5780 ZE2UR_CALL (zeDriverGet, (&ZeDriverCount, nullptr ));
5881 if (ZeDriverCount == 0 ) {
@@ -65,24 +88,37 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
6588
6689 ZE2UR_CALL (zeDriverGet, (&ZeDriverCount, ZeDrivers.data ()));
6790 for (uint32_t I = 0 ; I < ZeDriverCount; ++I) {
91+ bool DriverInit = false ;
6892 ze_device_properties_t device_properties{};
6993 device_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
7094 uint32_t ZeDeviceCount = 0 ;
7195 ZE2UR_CALL (zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, nullptr ));
7296 ZeDevices.resize (ZeDeviceCount);
7397 ZE2UR_CALL (zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, ZeDevices.data ()));
98+ auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
7499 // Check if this driver has GPU Devices
75100 for (uint32_t D = 0 ; D < ZeDeviceCount; ++D) {
76101 ZE2UR_CALL (zeDeviceGetProperties, (ZeDevices[D], &device_properties));
77-
78102 if (ZE_DEVICE_TYPE_GPU == device_properties.type ) {
79- // If this Driver is a GPU, save it as a usable platform.
80- auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
81- UR_CALL (platform->initialize ());
103+ if (!DriverInit) {
104+ // If this Driver is a GPU, save it as a usable platform.
105+ UR_CALL (platform->initialize ());
82106
83- // Save a copy in the cache for future uses.
84- platforms.push_back (std::move (platform));
85- break ;
107+ // Save a copy in the cache for future uses.
108+ platforms.push_back (std::move (platform));
109+ DriverInit = true ;
110+ }
111+ if (ZesResult == ZE_RESULT_SUCCESS) {
112+ ur_zes_device_handle_data_t ZesDeviceData;
113+ zes_uuid_t ZesUUID;
114+ std::memcpy (&ZesUUID, &device_properties.uuid , sizeof (zes_uuid_t ));
115+ if (getZesDeviceHandle (
116+ ZesUUID, &ZesDeviceData.ZesDevice , &ZesDeviceData.SubDeviceId ,
117+ &ZesDeviceData.SubDevice ) == UR_RESULT_SUCCESS) {
118+ platforms.back ()->ZedeviceToZesDeviceMap .insert (
119+ std::make_pair (ZeDevices[D], std::move (ZesDeviceData)));
120+ }
121+ }
86122 }
87123 }
88124 }
@@ -172,7 +208,9 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
172208 return ;
173209 }
174210
175- ur_result_t err = initPlatforms (platforms);
211+ GlobalAdapter->ZesResult = ZE_CALL_NOCHECK (zesInit, (0 ));
212+
213+ ur_result_t err = initPlatforms (platforms, *GlobalAdapter->ZesResult );
176214 if (err == UR_RESULT_SUCCESS) {
177215 result = std::move (platforms);
178216 } else {
0 commit comments