@@ -162,6 +162,7 @@ struct vk_device_struct {
162162    uint32_t  subgroup_size;
163163    uint32_t  shader_core_count;
164164    bool  uma;
165+     bool  float_controls_rte_fp16;
165166    bool  coopmat2;
166167
167168    bool  coopmat_support;
@@ -1916,17 +1917,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
19161917    ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16_wg512 , " soft_max_f32_f16_wg512"  , soft_max_f32_f16_len, soft_max_f32_f16_data, " main"  , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512  }, 1 );
19171918
19181919    ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f32 , " rope_norm_f32"  , rope_norm_f32_len, rope_norm_f32_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1919-     ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16"  , rope_norm_f16_len, rope_norm_f16_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1920- 
19211920    ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f32 , " rope_neox_f32"  , rope_neox_f32_len, rope_neox_f32_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1922-     ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16"  , rope_neox_f16_len, rope_neox_f16_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1921+ 
1922+     if  (device->float_controls_rte_fp16 ) {
1923+         ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16"  , rope_norm_f16_rte_len, rope_norm_f16_rte_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1924+         ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16"  , rope_neox_f16_rte_len, rope_neox_f16_rte_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1925+     } else  {
1926+         ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16"  , rope_norm_f16_len, rope_norm_f16_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1927+         ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16"  , rope_neox_f16_len, rope_neox_f16_data, " main"  , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1928+     }
19231929
19241930    ggml_vk_create_pipeline (device, device->pipeline_argsort_f32 , " argsort_f32"  , argsort_f32_len, argsort_f32_data, " main"  , 2 , sizeof (vk_op_argsort_push_constants), {1024 , 1 , 1 }, {}, 1 );
19251931
19261932    ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32"  , sum_rows_f32_len, sum_rows_f32_data, " main"  , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size  }, 1 );
19271933
19281934    ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32"  , im2col_f32_len, im2col_f32_data, " main"  , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1929-     ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16"  , im2col_f32_f16_len, im2col_f32_f16_data, " main"  , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1935+     if  (device->float_controls_rte_fp16 ) {
1936+         ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16"  , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main"  , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1937+     } else  {
1938+         ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16"  , im2col_f32_f16_len, im2col_f32_f16_data, " main"  , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1939+     }
19301940
19311941    ggml_vk_create_pipeline (device, device->pipeline_timestep_embedding_f32 , " timestep_embedding_f32"  , timestep_embedding_f32_len, timestep_embedding_f32_data, " main"  , 2 , sizeof (vk_op_timestep_embedding_push_constants), {256 , 1 , 1 }, {}, 1 );
19321942
@@ -2007,11 +2017,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
20072017        vk::PhysicalDeviceDriverProperties driver_props;
20082018        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20092019        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2020+         vk::PhysicalDeviceVulkan12Properties vk12_props;
20102021        props2.pNext  = &props3;
20112022        props3.pNext  = &subgroup_props;
20122023        subgroup_props.pNext  = &driver_props;
2024+         driver_props.pNext  = &vk12_props;
20132025
2014-         VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props ;
2026+         VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props ;
20152027
20162028        if  (maintenance4_support) {
20172029            last_struct->pNext  = (VkBaseOutStructure *)&props4;
@@ -2057,6 +2069,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
20572069        } else  {
20582070            device->shader_core_count  = 0 ;
20592071        }
2072+         device->float_controls_rte_fp16  = vk12_props.shaderRoundingModeRTEFloat16 ;
20602073
20612074        const  bool  force_disable_f16 = getenv (" GGML_VK_DISABLE_F16"  ) != nullptr ;
20622075
0 commit comments