@@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
-uint16_t f8_e4m3_to_f16(uint8_t f8) {
-    // do we need to support uz?
-
-    const uint32_t exponent_bias = 7;
-    if (f8 == 0xff) {
-        return ggml_fp32_to_fp16(-NAN);
-    } else if (f8 == 0x7f) {
-        return ggml_fp32_to_fp16(NAN);
+uint16_t f8_e3m4_to_f16(uint8_t fp8) {
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint8_t exponent_bias = 0x3;  // 2^(3-1)-1
+    const uint8_t f16_bias = 0xF;       // 2^(5-1)-1
+    const int mantissa_bits = 4;
+    const uint8_t mantissa_max = 0xF;   // 2^4-1
 
-    uint32_t sign = f8 & 0x80;
-    uint32_t exponent = (f8 & 0x78) >> 3;
-    uint32_t mantissa = f8 & 0x07;
-    uint32_t result = sign << 24;
-    if (exponent == 0) {
-        if (mantissa > 0) {
-            exponent = 0x7f - exponent_bias;
-
-            // yes, 2 times
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-            result |= (mantissa & 0x03) << 21;
-            result |= exponent << 23;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
         }
-    } else {
-        result |= mantissa << 20;
-        exponent += 0x7f - exponent_bias;
-        result |= exponent << 23;
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 6;
 
-    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
 }
 
-uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) {  // zero
-        return fp16_sign;
+uint16_t f8_e4m3_to_f16(uint8_t fp8) {
+    // do we need to support uz?
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint8_t exponent_bias = 0x7;  // 2^(4-1)-1
+    const uint8_t f16_bias = 0xF;       // 2^(5-1)-1
+    const int mantissa_bits = 3;
+    const uint8_t mantissa_max = 0x7;   // 2^3-1
 
-    if (exponent == 0x1F) {  // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-    if (exponent == 0) {  // subnormal numbers
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 7;
 
-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}
 
-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    // do we need to support fnuz?
+    return static_cast<uint16_t>(fp8) << 8;
 }
 
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
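The new decoders can be sanity-checked against a few hand-computed bit patterns. The snippet below is only an illustrative sketch, not part of the patch; it assumes the three f8_*_to_f16 helpers above are linkable from a small test translation unit and compares raw FP16 bit patterns, so no ggml conversion helpers are needed.

// f8_decode_check.cpp -- hypothetical standalone sanity check, not part of the patch
#include <cassert>
#include <cstdint>
#include <cstdio>

uint16_t f8_e3m4_to_f16(uint8_t fp8);  // assumed visible from the file patched above
uint16_t f8_e4m3_to_f16(uint8_t fp8);
uint16_t f8_e5m2_to_f16(uint8_t fp8);

int main() {
    // E3M4: 0x30 = exponent 3 (bias 3), mantissa 0 -> +1.0, FP16 0x3C00
    assert(f8_e3m4_to_f16(0x30) == 0x3C00);
    // E4M3: 0x38 = exponent 7 (bias 7), mantissa 0 -> +1.0, FP16 0x3C00
    assert(f8_e4m3_to_f16(0x38) == 0x3C00);
    // E4M3: 0xC0 -> -2.0, FP16 0xC000
    assert(f8_e4m3_to_f16(0xC0) == 0xC000);
    // E4M3 subnormal 0x01 = 2^-9; the while loop renormalizes it to FP16 exponent 6 -> 0x1800
    assert(f8_e4m3_to_f16(0x01) == 0x1800);
    // E5M2 has the same exponent width as FP16, so the byte is simply the high byte: 0x3C -> 1.0
    assert(f8_e5m2_to_f16(0x3C) == 0x3C00);
    printf("all f8 decode checks passed\n");
    return 0;
}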
@@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     }
 }
 
+void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e3m4_to_f16(src[i]);
+    }
+}
+
 void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
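These vec helpers convert in place: element i of the 2-byte destination occupies bytes 2*i and 2*i+1, so a forward loop would overwrite source bytes that have not been read yet, while iterating from n-1 down to 0 only ever writes at or past the current read position. A minimal sketch of the same widening pattern, with hypothetical names and independent of the patch:

// widen_in_place.cpp -- why the vec converters iterate from the back
#include <cassert>
#include <cstdint>
#include <vector>

static void widen_in_place(uint8_t* buf, int64_t n) {
    uint8_t* src = buf;                                // n one-byte source values
    uint16_t* dst = reinterpret_cast<uint16_t*>(buf);  // same buffer, reused for n two-byte values
    for (int64_t i = n - 1; i >= 0; i--) {             // back to front, otherwise src[i + 1 ...] would be clobbered
        dst[i] = static_cast<uint16_t>(src[i]);
    }
}

int main() {
    std::vector<uint8_t> buf = {1, 2, 3, 4, 0, 0, 0, 0};  // 4 source bytes, room for 4 uint16_t
    widen_in_place(buf.data(), 4);
    const uint16_t* out = reinterpret_cast<const uint16_t*>(buf.data());
    assert(out[0] == 1 && out[1] == 2 && out[2] == 3 && out[3] == 4);
    return 0;
}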
@@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E3M4") {
+        ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
@@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E3M4") {
+            tensor_storage.is_f8_e3m4 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else if (dtype == "F8_E4M3") {
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
@@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
-    bool has_multiple_encoders = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;
 
-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1479,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if (has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1488,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if (is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
@@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
@@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
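Putting the new dtype together: str_to_ggml_type maps "F8_E3M4" to GGML_TYPE_F16, so tensor_storage.nbytes() already counts 2 bytes per element while the file holds only 1 (hence the tensor_data_size * 2 assertion), and load_tensors can read the packed bytes into the front of the f16-sized buffer and widen them in place. A condensed sketch of that flow follows; the file handling and names (load_f8_e3m4, fp, n_elements) are hypothetical, only the two helpers come from the patch.

// load_f8_e3m4.cpp -- illustrative only; the real buffer management lives in ModelLoader::load_tensors
#include <cstdint>
#include <cstdio>
#include <vector>

uint16_t f8_e3m4_to_f16(uint8_t fp8);                             // from the patch above
void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n);  // from the patch above

// Read n_elements packed E3M4 bytes from fp and return them widened to FP16 bit patterns.
std::vector<uint16_t> load_f8_e3m4(std::FILE* fp, int64_t n_elements) {
    // Buffer sized for the f16 result (2 bytes per element), mirroring nbytes() == tensor_data_size * 2.
    std::vector<uint8_t> buf(n_elements * sizeof(uint16_t));
    size_t got = std::fread(buf.data(), sizeof(uint8_t), n_elements, fp);  // only the f8 payload comes from disk
    if (got != (size_t)n_elements) {
        return {};
    }
    f8_e3m4_to_f16_vec(buf.data(), reinterpret_cast<uint16_t*>(buf.data()), n_elements);  // in-place widen
    const uint16_t* p = reinterpret_cast<const uint16_t*>(buf.data());
    return std::vector<uint16_t>(p, p + n_elements);
}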