@@ -20,7 +20,7 @@ use smallvec::SmallVec;
2020
2121use crate :: attributes:: { self , llfn_attrs_from_instance} ;
2222use crate :: builder:: Builder ;
23- use crate :: context:: CodegenCx ;
23+ use crate :: context:: { CodegenCx , GenericCx , SCx } ;
2424use crate :: llvm:: { self , Attribute , AttributePlace } ;
2525use crate :: type_:: Type ;
2626use crate :: type_of:: LayoutLlvmExt ;
@@ -362,15 +362,15 @@ fn match_intrinsic_signature<'ll>(
362362 ) ;
363363 }
364364
365- if !equate_ty ( cx , rust_return_ty, llvm_return_ty) {
365+ if !cx . equate_ty ( rust_return_ty, llvm_return_ty) {
366366 error ! (
367367 "Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as return type for `{fn_name}`"
368368 ) ;
369369 }
370370 for ( idx, ( & rust_argument_ty, llvm_argument_ty) ) in
371371 iter:: zip ( rust_argument_tys, llvm_argument_tys) . enumerate ( )
372372 {
373- if !equate_ty ( cx , rust_argument_ty, llvm_argument_ty) {
373+ if !cx . equate_ty ( rust_argument_ty, llvm_argument_ty) {
374374 error ! (
375375 "Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as argument {idx} for `{fn_name}`"
376376 ) ;
@@ -380,28 +380,53 @@ fn match_intrinsic_signature<'ll>(
380380 fn_ty
381381}
382382
383- fn equate_ty < ' ll > ( cx : & CodegenCx < ' ll , ' _ > , rust_ty : & ' ll Type , llvm_ty : & ' ll Type ) -> bool {
384- if rust_ty == llvm_ty {
385- return true ;
386- }
387- if cx. type_kind ( llvm_ty) == TypeKind :: X86_AMX && cx. type_kind ( rust_ty) == TypeKind :: Vector {
388- let element_count = cx. vector_length ( rust_ty) ;
389- let element_ty = cx. element_type ( rust_ty) ;
390-
391- let element_size_bits = match cx. type_kind ( element_ty) {
392- TypeKind :: Half => 16 ,
393- TypeKind :: Float => 32 ,
394- TypeKind :: Double => 64 ,
395- TypeKind :: FP128 => 128 ,
396- TypeKind :: Integer => cx. int_width ( element_ty) ,
397- TypeKind :: Pointer => cx. int_width ( cx. isize_ty ) ,
398- _ => bug ! ( "Vector element type `{element_ty:?}` not one of integer, float or pointer" ) ,
399- } ;
400- let vector_size_bits = element_size_bits * element_count as u64 ;
383+ impl < ' ll , CX : Borrow < SCx < ' ll > > > GenericCx < ' ll , CX > {
384+ pub ( crate ) fn equate_ty ( & self , rust_ty : & ' ll Type , llvm_ty : & ' ll Type ) -> bool {
385+ if rust_ty == llvm_ty {
386+ return true ;
387+ }
388+
389+ match self . type_kind ( llvm_ty) {
390+ TypeKind :: X86_AMX if self . type_kind ( rust_ty) == TypeKind :: Vector => {
391+ let element_count = self . vector_length ( rust_ty) ;
392+ let element_ty = self . element_type ( rust_ty) ;
393+
394+ let element_size_bits = match self . type_kind ( element_ty) {
395+ TypeKind :: Half => 16 ,
396+ TypeKind :: Float => 32 ,
397+ TypeKind :: Double => 64 ,
398+ TypeKind :: FP128 => 128 ,
399+ TypeKind :: Integer => self . int_width ( element_ty) ,
400+ TypeKind :: Pointer => self . int_width ( self . isize_ty ( ) ) ,
401+ _ => bug ! (
402+ "Vector element type `{element_ty:?}` not one of integer, float or pointer"
403+ ) ,
404+ } ;
405+ let vector_size_bits = element_size_bits * element_count as u64 ;
406+
407+ vector_size_bits == 8192
408+ }
409+ TypeKind :: BFloat => rust_ty == self . type_i16 ( ) ,
410+ TypeKind :: Vector if self . type_kind ( rust_ty) == TypeKind :: Vector => {
411+ let llvm_element_count = self . vector_length ( llvm_ty) ;
412+ let rust_element_count = self . vector_length ( rust_ty) ;
401413
402- return vector_size_bits == 8192 ;
414+ if llvm_element_count != rust_element_count {
415+ return false ;
416+ }
417+
418+ let llvm_element_ty = self . element_type ( llvm_ty) ;
419+ let rust_element_ty = self . element_type ( rust_ty) ;
420+
421+ if llvm_element_ty == self . type_bf16 ( ) {
422+ rust_element_ty == self . type_i16 ( )
423+ } else {
424+ false
425+ }
426+ }
427+ _ => false ,
428+ }
403429 }
404- return false ;
405430}
406431
407432impl < ' ll , ' tcx > FnAbiLlvmExt < ' ll , ' tcx > for FnAbi < ' tcx , Ty < ' tcx > > {
0 commit comments