@@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager(
653653 // We then run the llvm_optimize function a second time, to optimize the code which we generated
654654 // in the enzyme differentiation pass.
655655 let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
656+ let enable_gpu = true ; //config.offload.contains(&config::Offload::Enable);
656657 let stage = if thin {
657658 write:: AutodiffStage :: PreAD
658659 } else {
@@ -667,6 +668,114 @@ pub(crate) fn run_pass_manager(
667668 write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
668669 }
669670
671+ if cfg ! ( llvm_enzyme) && enable_gpu && !thin {
672+ // first we need to add all the fun to the host module
673+ // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
674+ // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
675+ let cx =
676+ SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
677+ if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
678+ let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
679+ let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
680+ let tptr = cx. type_ptr ( ) ;
681+ let ti64 = cx. type_i64 ( ) ;
682+ let ti32 = cx. type_i32 ( ) ;
683+ let ti16 = cx. type_i16 ( ) ;
684+ let tarr = cx. type_array ( ti32, 3 ) ;
685+
686+ let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
687+ let kernel_elements = vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
688+
689+ cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
690+ cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
691+ let global = cx. declare_global ( "my_struct_global" , offload_entry_ty) ;
692+ let global = cx. declare_global ( "my_struct_global2" , kernel_arguments_ty) ;
693+ dbg ! ( & offload_entry_ty) ;
694+ dbg ! ( & kernel_arguments_ty) ;
695+ //LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
696+ //LLVMStructSetBody(structTy, elements, 9, 0);
697+ dbg ! ( "created struct" ) ;
698+ for num in 0 ..5 {
699+ if !cx. get_function ( & format ! ( "kernel_{num}" ) ) . is_some ( ) {
700+ continue ;
701+ }
702+ //for function in cx.get_functions() {
703+ //if !attributes::has_attr(function, Function, llvm::AttributeKind::OptimizeForSize) {
704+ // dbg!("skipping minsize fnc");
705+ // dbg!(&function);
706+ // // print fnc name
707+ // let enzyme_marker = "minsize";
708+ // if attributes::has_string_attr(function, enzyme_marker) {
709+ // dbg!("found minsize str");
710+ // }
711+ // continue;
712+
713+ let size_name = format ! ( ".offload_sizes.{num}" ) ;
714+ let size_ty = cx. type_array ( ti64, 4 ) ;
715+ //let size_val = vec![8i64,0,16,0];
716+ let c_val_8 = cx. get_const_i64 ( 8 ) ;
717+ let c_val_0 = cx. get_const_i64 ( 0 ) ;
718+ let c_val_16 = cx. get_const_i64 ( 16 ) ;
719+ let size_val = vec ! [ c_val_8, c_val_0, c_val_16, c_val_0] ;
720+
721+ //let val = cx.define_global(&size_name, size_ty).unwrap();
722+ //dbg!(&val);
723+ //let section_var = cx
724+ // .define_global(section_var_name, llvm_type)
725+ // .unwrap_or_else(|| bug!("symbol `{}` is already defined", section_var_name));
726+ //llvm::set_section(section_var, c".debug_gdb_scripts");
727+ //llvm::set_initializer(section_var, cx.const_bytes(section_contents));
728+ //llvm::LLVMSetGlobalConstant(section_var, llvm::True);
729+ //llvm::set_linkage(section_var, llvm::Linkage::LinkOnceODRLinkage);
730+ //// This should make sure that the whole section is not larger than
731+ //// the string it contains. Otherwise we get a warning from GDB.
732+ //llvm::LLVMSetAlignment(section_var, 1);
733+ //llvm::set_initializer(val, cx.const_bytes(size_val.as_slice()));
734+ let initializer = cx. const_array ( ti64, & size_val) ;
735+ let name = format ! ( ".offload_sizes.{num}" ) ;
736+ let c_name = CString :: new ( name) . unwrap ( ) ;
737+ let array = llvm:: add_global ( cx. llmod , cx. val_ty ( initializer) , & c_name ) ;
738+ llvm:: set_global_constant ( array, true ) ;
739+ unsafe { llvm:: LLVMSetUnnamedAddress ( array, llvm:: UnnamedAddr :: Global ) } ;
740+ llvm:: set_linkage ( array, llvm:: Linkage :: PrivateLinkage ) ;
741+ llvm:: set_initializer ( array, initializer) ;
742+ dbg ! ( & array) ;
743+ // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
744+ // 2. @.offload_maptypes
745+ // 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
746+ // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
747+ // 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
748+ }
749+ // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
750+ // @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
751+ // @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
752+ // @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id = weak constant i8 0
753+ // @.offload_sizes.1 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
754+ // @.offload_maptypes.2 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
755+ // @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id = weak constant i8 0
756+ // @.offload_sizes.3 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
757+ // @.offload_maptypes.4 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
758+ // @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
759+ // @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
760+ // @_ZSt4cout = external global %"class.std::basic_ostream", align 8
761+ // @.str = private unnamed_addr constant [3 x i8] c"hi\00", align 1
762+ // @.offload_sizes.7 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
763+ // @.offload_maptypes.8 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
764+ // @.str.9 = private unnamed_addr constant [3 x i8] c"ho\00", align 1
765+ // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
766+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
767+ // @.offloading.entry_name.10 = internal unnamed_addr constant [67 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13\00", section ".llvm.rodata.offloading", align 1
768+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id, ptr @.offloading.entry_name.10, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
769+ // @.offloading.entry_name.11 = internal unnamed_addr constant [69 x i8] c"__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19\00", section ".llvm.rodata.offloading", align 1
770+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id, ptr @.offloading.entry_name.11, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
771+ // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_zaxpy.cpp, ptr null }]
772+ } else {
773+ dbg ! ( "no marker found" ) ;
774+ }
775+ } else {
776+ dbg ! ( "Not creating struct" ) ;
777+ }
778+
670779 if cfg ! ( llvm_enzyme) && enable_ad && !thin {
671780 let cx =
672781 SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
0 commit comments