Skip to content

Commit a08aaf1

Browse files
committed
fix: prevent loading Vulkan if the device is unsupported
1 parent b3ec8c2 commit a08aaf1

File tree

8 files changed

+30
-1
lines changed

8 files changed

+30
-1
lines changed

llama/addon/addon.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
237237
Napi::PropertyDescriptor::Function("getGpuVramInfo", getGpuVramInfo),
238238
Napi::PropertyDescriptor::Function("getGpuDeviceInfo", getGpuDeviceInfo),
239239
Napi::PropertyDescriptor::Function("getGpuType", getGpuType),
240+
Napi::PropertyDescriptor::Function("ensureGpuDeviceIsSupported", ensureGpuDeviceIsSupported),
240241
Napi::PropertyDescriptor::Function("getSwapInfo", getSwapInfo),
241242
Napi::PropertyDescriptor::Function("getMemoryInfo", getMemoryInfo),
242243
Napi::PropertyDescriptor::Function("loadBackends", addonLoadBackends),

llama/addon/globals/getGpuInfo.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,14 @@ Napi::Value getGpuType(const Napi::CallbackInfo& info) {
128128

129129
return info.Env().Undefined();
130130
}
131+
132+
Napi::Value ensureGpuDeviceIsSupported(const Napi::CallbackInfo& info) {
133+
#ifdef GPU_INFO_USE_VULKAN
134+
if (!checkIsVulkanEnvSupported()) {
135+
Napi::Error::New(info.Env(), "Vulkan device is not supported").ThrowAsJavaScriptException();
136+
return info.Env().Undefined();
137+
}
138+
#endif
139+
140+
return info.Env().Undefined();
141+
}

llama/addon/globals/getGpuInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info);
88
Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info);
99
std::pair<ggml_backend_dev_t, std::string> getGpuDevice();
10-
Napi::Value getGpuType(const Napi::CallbackInfo& info);
10+
Napi::Value getGpuType(const Napi::CallbackInfo& info);
11+
Napi::Value ensureGpuDeviceIsSupported(const Napi::CallbackInfo& info);

llama/gpuInfo/vulkan-gpu-info.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,13 @@ static bool enumerateVulkanDevices(size_t* total, size_t* used, size_t* unifiedM
8080
bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, size_t* unifiedMemorySize, gpuInfoVulkanWarningLogCallback_t warningLogCallback) {
8181
return enumerateVulkanDevices(total, used, unifiedMemorySize, false, nullptr, warningLogCallback);
8282
}
83+
84+
bool checkIsVulkanEnvSupported() {
85+
VkPhysicalDeviceVulkan11Features vk11_features;
86+
87+
if (!vk11_features.storageBuffer16BitAccess) {
88+
return false;
89+
}
90+
91+
return true;
92+
}

llama/gpuInfo/vulkan-gpu-info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message);
77

88
bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, size_t* unifiedMemorySize, gpuInfoVulkanWarningLogCallback_t warningLogCallback);
9+
bool checkIsVulkanEnvSupported();

src/bindings/AddonTypes.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ export type BindingModule = {
7272
deviceNames: string[]
7373
},
7474
getGpuType(): "cuda" | "vulkan" | "metal" | false | undefined,
75+
ensureGpuDeviceIsSupported(): void,
7576
getSwapInfo(): {
7677
total: number,
7778
maxSize: number,

src/bindings/Llama.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ export class Llama {
108108
if (loadedGpu == null || (loadedGpu === false && buildGpu !== false))
109109
bindings.loadBackends(path.dirname(bindingPath));
110110

111+
bindings.ensureGpuDeviceIsSupported();
112+
111113
this._gpu = bindings.getGpuType() ?? false;
112114
this._supportsGpuOffloading = bindings.getSupportsGpuOffloading();
113115
this._supportsMmap = bindings.getSupportsMmap();

src/bindings/utils/testBindingBinary.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ if (process.env.TEST_BINDING_CP === "true" && (process.parentPort != null || pro
211211
if (gpuType !== message.gpu)
212212
throw new Error(`Binary GPU type mismatch. Expected: ${message.gpu}, got: ${gpuType}`);
213213

214+
binding.ensureGpuDeviceIsSupported();
215+
214216
sendMessage({type: "done"});
215217
} catch (err) {
216218
console.error(err);

0 commit comments

Comments
 (0)