Skip to content

Commit 96987c0

Browse files
committed
fix: prevent loading Vulkan if the device is unsupported
1 parent a08aaf1 commit 96987c0

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

llama/addon/globals/getGpuInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Napi::Value getGpuType(const Napi::CallbackInfo& info) {
131131

132132
Napi::Value ensureGpuDeviceIsSupported(const Napi::CallbackInfo& info) {
133133
#ifdef GPU_INFO_USE_VULKAN
134-
if (!checkIsVulkanEnvSupported()) {
134+
if (!checkIsVulkanEnvSupported(logVulkanWarning)) {
135135
Napi::Error::New(info.Env(), "Vulkan device is not supported").ThrowAsJavaScriptException();
136136
return info.Env().Undefined();
137137
}

llama/gpuInfo/vulkan-gpu-info.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message);
77

8-
static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedMemorySize, bool addDeviceNames, std::vector<std::string> * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) {
8+
static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedMemorySize, bool addDeviceNames, std::vector<std::string> * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback, bool * checkSupported) {
99
vk::ApplicationInfo appInfo("node-llama-cpp GPU info", 1, "llama.cpp", 1, VK_API_VERSION_1_2);
1010
vk::InstanceCreateInfo createInfo(vk::InstanceCreateFlags(), &appInfo, {}, {});
1111
vk::Instance instance = vk::createInstance(createInfo);
@@ -56,6 +56,24 @@ static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedM
5656
if (size > 0 && addDeviceNames) {
5757
(*deviceNames).push_back(std::string(deviceProps.deviceName.data()));
5858
}
59+
60+
if (checkSupported != nullptr && checkSupported) {
61+
VkPhysicalDeviceFeatures2 device_features2;
62+
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
63+
device_features2.pNext = nullptr;
64+
device_features2.features = (VkPhysicalDeviceFeatures)physicalDevice.getFeatures();
65+
66+
VkPhysicalDeviceVulkan11Features vk11_features;
67+
vk11_features.pNext = nullptr;
68+
vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
69+
device_features2.pNext = &vk11_features;
70+
71+
vkGetPhysicalDeviceFeatures2(physicalDevice, &device_features2);
72+
73+
if (!vk11_features.storageBuffer16BitAccess) {
74+
checkSupported = false;
75+
}
76+
}
5977
}
6078
}
6179
} else {
@@ -78,15 +96,16 @@ static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedM
7896
}
7997

8098
bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, size_t* unifiedMemorySize, gpuInfoVulkanWarningLogCallback_t warningLogCallback) {
81-
return enumerateVulkanDevices(total, used, unifiedMemorySize, false, nullptr, warningLogCallback);
99+
return enumerateVulkanDevices(total, used, unifiedMemorySize, false, nullptr, warningLogCallback, nullptr);
82100
}
83101

84-
bool checkIsVulkanEnvSupported() {
85-
VkPhysicalDeviceVulkan11Features vk11_features;
102+
bool checkIsVulkanEnvSupported(gpuInfoVulkanWarningLogCallback_t warningLogCallback) {
103+
size_t total = 0;
104+
size_t used = 0;
105+
size_t unifiedMemorySize = 0;
86106

87-
if (!vk11_features.storageBuffer16BitAccess) {
88-
return false;
89-
}
107+
bool isSupported = true;
108+
enumerateVulkanDevices(&total, &used, &unifiedMemorySize, false, nullptr, warningLogCallback, &isSupported);
90109

91-
return true;
110+
return isSupported;
92111
}

llama/gpuInfo/vulkan-gpu-info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message);
77

88
bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, size_t* unifiedMemorySize, gpuInfoVulkanWarningLogCallback_t warningLogCallback);
9-
bool checkIsVulkanEnvSupported();
9+
bool checkIsVulkanEnvSupported(gpuInfoVulkanWarningLogCallback_t warningLogCallback);

0 commit comments

Comments
 (0)