Skip to content

Commit 8b414c1

Browse files
committed
fix(Vulkan): deduplicate the same device coming from different drivers
1 parent 76aea27 commit 8b414c1

File tree

1 file changed

+93
-2
lines changed

1 file changed

+93
-2
lines changed

llama/gpuInfo/vulkan-gpu-info.cpp

Lines changed: 93 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,107 @@
11
#include <stddef.h>
2+
#include <map>
23
#include <vector>
34

45
#include <vulkan/vulkan.hpp>
56

7+
// Vendor identifiers that are matched against `properties.vendorID` of a
// physical device when ranking the drivers that expose it.
// NOTE(review): these look like the PCI vendor IDs of the respective
// companies - confirm against the Vulkan specification.
constexpr uint32_t VK_VENDOR_ID_AMD{0x1002};
constexpr uint32_t VK_VENDOR_ID_APPLE{0x106b};
constexpr uint32_t VK_VENDOR_ID_INTEL{0x8086};
constexpr uint32_t VK_VENDOR_ID_NVIDIA{0x10de};
11+
612
// Pointer to a function that receives a warning message as a C string.
using gpuInfoVulkanWarningLogCallback_t = void (*)(const char* message);
713

8-
static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedMemorySize, bool addDeviceNames, std::vector<std::string> * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback, bool * checkSupported) {
14+
// Creates a Vulkan instance for GPU-info queries, requesting API version 1.2
// with no layers and no extensions enabled.
//
// Returns the `vk::Instance` produced by `vk::createInstance`; nothing in this
// function destroys it, so the caller decides how long it lives.
static vk::Instance vulkanInstance() {
    const vk::ApplicationInfo appInfo("node-llama-cpp GPU info", 1, "llama.cpp", 1, VK_API_VERSION_1_2);
    const vk::InstanceCreateInfo createInfo(vk::InstanceCreateFlags(), &appInfo, {}, {});

    // Create exactly one instance: returning the result directly avoids the
    // extra, never-used instance a separate local `createInstance` call made.
    return vk::createInstance(createInfo);
}
1219

20+
static std::vector<vk::PhysicalDevice> dedupedDevices() {
21+
vk::Instance instance = vulkanInstance();
1322
auto physicalDevices = instance.enumeratePhysicalDevices();
23+
std::vector<vk::PhysicalDevice> dedupedDevices;
24+
dedupedDevices.reserve(physicalDevices.size());
25+
26+
// adapted from `ggml_vk_instance_init` in `ggml-vulkan.cpp`
27+
for (const auto& device : physicalDevices) {
28+
vk::PhysicalDeviceProperties2 newProps;
29+
vk::PhysicalDeviceDriverProperties newDriver;
30+
vk::PhysicalDeviceIDProperties newId;
31+
newProps.pNext = &newDriver;
32+
newDriver.pNext = &newId;
33+
device.getProperties2(&newProps);
34+
35+
auto oldDevice = std::find_if(
36+
dedupedDevices.begin(),
37+
dedupedDevices.end(),
38+
[&newId](const vk::PhysicalDevice& oldDevice) {
39+
vk::PhysicalDeviceProperties2 oldProps;
40+
vk::PhysicalDeviceDriverProperties oldDriver;
41+
vk::PhysicalDeviceIDProperties oldId;
42+
oldProps.pNext = &oldDriver;
43+
oldDriver.pNext = &oldId;
44+
oldDevice.getProperties2(&oldProps);
45+
46+
bool equals = std::equal(std::begin(oldId.deviceUUID), std::end(oldId.deviceUUID), std::begin(newId.deviceUUID));
47+
equals |= oldId.deviceLUIDValid && newId.deviceLUIDValid &&
48+
std::equal(std::begin(oldId.deviceLUID), std::end(oldId.deviceLUID), std::begin(newId.deviceLUID));
49+
50+
return equals;
51+
}
52+
);
53+
54+
if (oldDevice == dedupedDevices.end()) {
55+
dedupedDevices.push_back(device);
56+
continue;
57+
}
58+
59+
vk::PhysicalDeviceProperties2 oldProps;
60+
vk::PhysicalDeviceDriverProperties oldDriver;
61+
oldProps.pNext = &oldDriver;
62+
oldDevice->getProperties2(&oldProps);
63+
64+
std::map<vk::DriverId, int> driverPriorities {};
65+
int oldPriority = std::numeric_limits<int>::max();
66+
int newPriority = std::numeric_limits<int>::max();
67+
68+
switch (oldProps.properties.vendorID) {
69+
case VK_VENDOR_ID_AMD:
70+
driverPriorities[vk::DriverId::eMesaRadv] = 1;
71+
driverPriorities[vk::DriverId::eAmdOpenSource] = 2;
72+
driverPriorities[vk::DriverId::eAmdProprietary] = 3;
73+
break;
74+
case VK_VENDOR_ID_INTEL:
75+
driverPriorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
76+
driverPriorities[vk::DriverId::eIntelProprietaryWindows] = 2;
77+
break;
78+
case VK_VENDOR_ID_NVIDIA:
79+
driverPriorities[vk::DriverId::eNvidiaProprietary] = 1;
80+
#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
81+
driverPriorities[vk::DriverId::eMesaNvk] = 2;
82+
#endif
83+
break;
84+
}
85+
driverPriorities[vk::DriverId::eMesaDozen] = 4;
86+
87+
if (driverPriorities.count(oldDriver.driverID)) {
88+
oldPriority = driverPriorities[oldDriver.driverID];
89+
}
90+
if (driverPriorities.count(newDriver.driverID)) {
91+
newPriority = driverPriorities[newDriver.driverID];
92+
}
93+
94+
if (newPriority < oldPriority) {
95+
dedupedDevices.erase(std::remove(dedupedDevices.begin(), dedupedDevices.end(), *oldDevice), dedupedDevices.end());
96+
dedupedDevices.push_back(device);
97+
}
98+
}
99+
100+
return dedupedDevices;
101+
}
102+
103+
static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedMemorySize, bool addDeviceNames, std::vector<std::string> * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback, bool * checkSupported) {
104+
auto physicalDevices = dedupedDevices();
14105

15106
size_t usedMem = 0;
16107
size_t totalMem = 0;

0 commit comments

Comments
 (0)