@@ -1424,6 +1424,7 @@ struct vk_instance_t {
14241424 PFN_vkGetCalibratedTimestampsEXT pfn_vkGetCalibratedTimestampsEXT = {};
14251425
14261426 std::vector<size_t> device_indices;
1427+ std::vector<bool> device_supports_membudget;
14271428 vk_device devices[GGML_VK_MAX_DEVICES];
14281429};
14291430
@@ -4431,7 +4432,6 @@ static void ggml_vk_instance_init() {
44314432 vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
44324433 vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
44334434 vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
4434-
44354435 }
44364436
44374437 vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
@@ -4448,10 +4448,12 @@ static void ggml_vk_instance_init() {
44484448 }
44494449 }
44504450
4451+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4452+
44514453 // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
44524454 char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
44534455 if (devices_env != nullptr) {
4454- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices() .size();
4456+ size_t num_available_devices = devices .size();
44554457
44564458 std::string devices(devices_env);
44574459 std::replace(devices.begin(), devices.end(), ',', ' ');
@@ -4466,8 +4468,6 @@ static void ggml_vk_instance_init() {
44664468 vk_instance.device_indices.push_back(tmp);
44674469 }
44684470 } else {
4469- std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4470-
44714471 // If no vulkan devices are found, return early
44724472 if (devices.empty()) {
44734473 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
@@ -4572,6 +4572,19 @@ static void ggml_vk_instance_init() {
45724572 GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
45734573
45744574 for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
4575+ vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
4576+ std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
4577+
4578+ bool membudget_supported = false;
4579+ for (const auto & ext : extensionprops) {
4580+ if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
4581+ membudget_supported = true;
4582+ break;
4583+ }
4584+ }
4585+
4586+ vk_instance.device_supports_membudget.push_back(membudget_supported);
4587+
45754588 ggml_vk_print_gpu_info(i);
45764589 }
45774590}
@@ -11881,15 +11894,29 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
1188111894
1188211895void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
1188311896 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
11897+ GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
1188411898
1188511899 vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
11900+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
11901+ vk::PhysicalDeviceMemoryProperties2 memprops = {};
11902+ bool membudget_supported = vk_instance.device_supports_membudget[device];
11903+
11904+ if (membudget_supported) {
11905+ memprops.pNext = &budgetprops;
11906+ }
11907+ vkdev.getMemoryProperties2(&memprops);
1188611908
11887- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
11909+ for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
11910+ const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
1188811911
11889- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
1189011912 if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
1189111913 *total = heap.size;
11892- *free = heap.size;
11914+
11915+ if (membudget_supported && i < budgetprops.heapUsage.size()) {
11916+ *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
11917+ } else {
11918+ *free = heap.size;
11919+ }
1189311920 break;
1189411921 }
1189511922 }
0 commit comments