From f70c6f4c3e35df56aff94414dd0108fd755418c0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 21 Aug 2025 12:01:32 -0700 Subject: [PATCH 1/4] fix host code --- .../src/builder/gpu_offload.rs | 141 +++++++++++++----- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 + tests/codegen-llvm/gpu_offload/gpu_host.rs | 70 ++++++--- 3 files changed, 151 insertions(+), 63 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 1280ab1442a09..eae9034e3c608 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -16,23 +16,41 @@ pub(crate) fn handle_gpu_code<'ll>( cx: &'ll SimpleCx<'_>, ) { // The offload memory transfer type for each kernel - let mut o_types = vec![]; - let mut kernels = vec![]; + let mut memtransfer_types = vec![]; + let mut region_ids = vec![]; let offload_entry_ty = add_tgt_offload_entry(&cx); for num in 0..9 { let kernel = cx.get_function(&format!("kernel_{num}")); if let Some(kernel) = kernel { - o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num)); - kernels.push(kernel); + let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num); + memtransfer_types.push(o); + region_ids.push(k); } } - gen_call_handling(&cx, &kernels, &o_types); + gen_call_handling(&cx, &memtransfer_types, ®ion_ids); +} + +// ; Function Attrs: nounwind +// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2 +fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) { + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let args = vec![tptr, ti64, ti32, ti32, tptr, tptr]; + let tgt_fn_ty = cx.type_func(&args, ti32); + let name = "__tgt_target_kernel"; + let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty); + let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx); + attributes::apply_to_llfn(tgt_decl, 
Function, &[nounwind]); + (tgt_decl, tgt_fn_ty) } // What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper: // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 +// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be +// offloaded was defined. fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 let unknown_txt = ";unknown;unknown;0;0;;"; @@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty offload_entry_ty } -fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { +fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); let tptr = cx.type_ptr(); let ti64 = cx.type_i64(); @@ -107,7 +125,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. // uint64_t Unused : 62; - // } Flags = {0, 0, 0}; + // } Flags = {0, 0, 0}; // totals to 64 Bit, 8 Byte // // The number of teams (for x,y,z dimension). // uint32_t NumTeams[3] = {0, 0, 0}; // // The number of threads (for x,y,z dimension). @@ -118,9 +136,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); - // For now we don't handle kernels, so for now we just add a global dummy - // to make sure that the __tgt_offload_entry is defined and handled correctly. 
- cx.declare_global("my_struct_global2", kernel_arguments_ty); + kernel_arguments_ty } fn gen_tgt_data_mappers<'ll>( @@ -187,7 +203,7 @@ fn gen_define_handling<'ll>( kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, num: i64, -) -> &'ll llvm::Value { +) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. @@ -205,10 +221,14 @@ fn gen_define_handling<'ll>( // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten // will be 2. For now, everything is 3, until we have our frontend set up. - let o_types = - add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]); + // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). + let memtransfer_types = add_priv_unnamed_arr( + &cx, + &format!(".offload_maptypes.{num}"), + &vec![1 + 2 + 32; num_ptr_types], + ); // Next: For each function, generate these three entries. 
A weak constant, - // the llvm.rodata entry name, and the omp_offloading_entries value + // the llvm.rodata entry name, and the llvm_offload_entries value let name = format!(".kernel_{num}.region_id"); let initializer = cx.get_const_i8(0); @@ -242,13 +262,13 @@ fn gen_define_handling<'ll>( llvm::set_global_constant(llglobal, true); llvm::set_linkage(llglobal, WeakAnyLinkage); llvm::set_initializer(llglobal, initializer); - llvm::set_alignment(llglobal, Align::ONE); - let c_section_name = CString::new(".omp_offloading_entries").unwrap(); + llvm::set_alignment(llglobal, Align::EIGHT); + let c_section_name = CString::new("llvm_offload_entries").unwrap(); llvm::set_section(llglobal, &c_section_name); - o_types + (memtransfer_types, region_id) } -fn declare_offload_fn<'ll>( +pub(crate) fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, ty: &'ll llvm::Type, @@ -285,9 +305,10 @@ fn declare_offload_fn<'ll>( // 6. generate __tgt_target_data_end calls to move data from the GPU fn gen_call_handling<'ll>( cx: &'ll SimpleCx<'_>, - _kernels: &[&'ll llvm::Value], - o_types: &[&'ll llvm::Value], + memtransfer_types: &[&'ll llvm::Value], + region_ids: &[&'ll llvm::Value], ) { + let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } let tptr = cx.type_ptr(); let ti32 = cx.type_i32(); @@ -295,7 +316,7 @@ fn gen_call_handling<'ll>( let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false); - gen_tgt_kernel_global(&cx); + let tgt_kernel_decl = gen_tgt_kernel_global(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); let main_fn = cx.get_function("main"); @@ -329,35 +350,32 @@ fn gen_call_handling<'ll>( // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16. 
let ty2 = cx.type_array(cx.type_i64(), num_args); let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes"); + + //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 + let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args"); + + // Step 1) + unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; + builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); + // Now we allocate once per function param, a copy to be passed to one of our maps. let mut vals = vec![]; let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); - for (index, in_ty) in types.iter().enumerate() { - // get function arg, store it into the alloca, and read it. - let p = llvm::get_param(called, index as u32); - let name = llvm::get_value_name(p); - let name = str::from_utf8(&name).unwrap(); - let arg_name = format!("{name}.addr"); - let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name); - - builder.store(p, alloca, Align::EIGHT); - let val = builder.load(in_ty, alloca, Align::EIGHT); - let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]); - vals.push(val); + for index in 0..types.len() { + let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() }; + let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); + vals.push(v); geps.push(gep); } - // Step 1) - unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; - builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); - let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void()); let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty); let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty); let init_ty = cx.type_func(&[], cx.type_void()); let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty); + // FIXME(offload): Later we want to add them to the wrapper code, rather than our 
main function. // call void @__tgt_register_lib(ptr noundef %6) builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None); // call void @__tgt_init_all_rtls() @@ -415,22 +433,63 @@ fn gen_call_handling<'ll>( // Step 2) let s_ident_t = generate_at_one(&cx); - let o = o_types[0]; + let o = memtransfer_types[0]; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); // Step 3) - // Here we will add code for the actual kernel launches in a follow-up PR. - // FIXME(offload): launch kernels + let mut values = vec![]; + let offload_version = cx.get_const_i32(3); + values.push((4, offload_version)); + values.push((4, cx.get_const_i32(num_args))); + values.push((8, geps.0)); + values.push((8, geps.1)); + values.push((8, geps.2)); + values.push((8, memtransfer_types[0])); + // The next two are debug infos. FIXME(offload) set them + values.push((8, cx.const_null(cx.type_ptr()))); + values.push((8, cx.const_null(cx.type_ptr()))); + values.push((8, cx.get_const_i64(0))); + values.push((8, cx.get_const_i64(0))); + let ti32 = cx.type_i32(); + let ci32_0 = cx.get_const_i32(0); + values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); + values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); + values.push((4, cx.get_const_i32(0))); + + for (i, value) in values.iter().enumerate() { + let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); + builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap()); + } + + let args = vec![ + s_ident_t, + // MAX == -1 + cx.get_const_i64(u64::MAX), + cx.get_const_i32(2097152), + cx.get_const_i32(256), + region_ids[0], + a5, + ]; + let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); + // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr 
%kernel_args) + unsafe { + let next = llvm::LLVMGetNextInstruction(offload_success).unwrap(); + llvm::LLVMRustPositionAfter(builder.llbuilder, next); + llvm::LLVMInstructionEraseFromParent(next); + } // Step 4) - unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; + //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t); builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); + drop(builder); + unsafe { llvm::LLVMDeleteFunction(called) }; + // With this we generated the following begin and end mappers. We could easily generate the // update mapper in an update. // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 2461f70a86e35..5dead7f4e7ee5 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1201,6 +1201,7 @@ unsafe extern "C" { // Operations on functions pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); + pub(crate) fn LLVMDeleteFunction(Fn: &Value); // Operations about llvm intrinsics pub(crate) fn LLVMLookupIntrinsicID(Name: *const c_char, NameLen: size_t) -> c_uint; @@ -1230,6 +1231,8 @@ unsafe extern "C" { pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>; + pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>; + pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value); // Operations on call sites pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint); diff --git 
a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index 513e27426bc0e..fac4054d1b7ff 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -21,16 +21,15 @@ fn main() { } // CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } -// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } // CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } +// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] -// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 35] // CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 // CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 -// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 -// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments +// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 // CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 // CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 @@ -43,34 +42,61 @@ fn main() { // CHECK-NEXT: 
%.offload_baseptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 -// CHECK-NEXT: %x.addr = alloca ptr, align 8 -// CHECK-NEXT: store ptr %x, ptr %x.addr, align 8 -// CHECK-NEXT: %1 = load ptr, ptr %x.addr, align 8 -// CHECK-NEXT: %2 = getelementptr inbounds float, ptr %1, i32 0 +// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 // CHECK: call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) +// CHECK-NEXT: %1 = getelementptr inbounds float, ptr %x, i32 0 // CHECK-NEXT: call void @__tgt_register_lib(ptr %EmptyDesc) // CHECK-NEXT: call void @__tgt_init_all_rtls() -// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %2 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %x, ptr %2, align 8 +// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 // CHECK-NEXT: store ptr %1, ptr %3, align 8 -// CHECK-NEXT: %4 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: store ptr %2, ptr %4, align 8 -// CHECK-NEXT: %5 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: store i64 1024, ptr %5, align 8 -// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %7 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null) -// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) -// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %10 = getelementptr 
inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %4 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: store i64 1024, ptr %4, align 8 +// CHECK-NEXT: %5 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %7 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %5, ptr %6, ptr %7, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %8 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0 +// CHECK-NEXT: store i32 3, ptr %8, align 4 +// CHECK-NEXT: %9 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1 +// CHECK-NEXT: store i32 1, ptr %9, align 4 +// CHECK-NEXT: %10 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2 +// CHECK-NEXT: store ptr %5, ptr %10, align 8 +// CHECK-NEXT: %11 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3 +// CHECK-NEXT: store ptr %6, ptr %11, align 8 +// CHECK-NEXT: %12 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4 +// CHECK-NEXT: store ptr %7, ptr %12, align 8 +// CHECK-NEXT: %13 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5 +// CHECK-NEXT: store ptr @.offload_maptypes.1, ptr %13, align 8 +// CHECK-NEXT: %14 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6 +// CHECK-NEXT: store ptr null, ptr %14, align 8 +// CHECK-NEXT: %15 = getelementptr inbounds 
%struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7 +// CHECK-NEXT: store ptr null, ptr %15, align 8 +// CHECK-NEXT: %16 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8 +// CHECK-NEXT: store i64 0, ptr %16, align 8 +// CHECK-NEXT: %17 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9 +// CHECK-NEXT: store i64 0, ptr %17, align 8 +// CHECK-NEXT: %18 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10 +// CHECK-NEXT: store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %18, align 4 +// CHECK-NEXT: %19 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11 +// CHECK-NEXT: store [3 x i32] [i32 256, i32 0, i32 0], ptr %19, align 4 +// CHECK-NEXT: %20 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12 +// CHECK-NEXT: store i32 0, ptr %20, align 4 +// CHECK-NEXT: %21 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) +// CHECK-NEXT: %22 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %23 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %24 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %22, ptr %23, ptr %24, ptr @.offload_maptypes.1, ptr null, ptr null) // CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc) // CHECK: store ptr %x, ptr %0, align 8 // CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) // CHECK: ret void // CHECK-NEXT: } +// CHECK: Function Attrs: nounwind +// CHECK: declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) + #[unsafe(no_mangle)] #[inline(never)] pub fn kernel_1(x: &mut [f32; 256]) { From 0f05703ed77b31dc4045b9a1bdedad4818fe04a0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald 
Date: Sun, 31 Aug 2025 15:17:35 -0700 Subject: [PATCH 2/4] model offload C++ structs through Rust structs --- .../src/builder/gpu_offload.rs | 171 ++++++++++-------- 1 file changed, 96 insertions(+), 75 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index eae9034e3c608..559180de3fe55 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>( // The offload memory transfer type for each kernel let mut memtransfer_types = vec![]; let mut region_ids = vec![]; - let offload_entry_ty = add_tgt_offload_entry(&cx); + let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); for num in 0..9 { let kernel = cx.get_function(&format!("kernel_{num}")); if let Some(kernel) = kernel { @@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm // FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be // offloaded was defined. fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { - // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 let unknown_txt = ";unknown;unknown;0;0;;"; let c_entry_name = CString::new(unknown_txt).unwrap(); let c_val = c_entry_name.as_bytes_with_nul(); @@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { at_one } -pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { - let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let ti16 = cx.type_i16(); - // For each kernel to run on the gpu, we will later generate one entry of this type. 
- // copied from LLVM - // typedef struct { +struct TgtOffloadEntry { // uint64_t Reserved; // uint16_t Version; // uint16_t Kind; @@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty // uint64_t Size; Size of the entry info (0 if it is a function) // uint64_t Data; // void *AuxAddr; - // } __tgt_offload_entry; - let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; - cx.set_struct_body(offload_entry_ty, &entry_elements, false); - offload_entry_ty } -fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { - let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let tarr = cx.type_array(ti32, 3); +impl TgtOffloadEntry { + pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { + let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + // For each kernel to run on the gpu, we will later generate one entry of this type. 
+ // copied from LLVM + let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + cx.set_struct_body(offload_entry_ty, &entry_elements, false); + offload_entry_ty + } + + fn new<'ll>( + cx: &'ll SimpleCx<'_>, + region_id: &'ll Value, + llglobal: &'ll Value, + ) -> Vec<&'ll Value> { + let reserved = cx.get_const_i64(0); + let version = cx.get_const_i16(1); + let kind = cx.get_const_i16(1); + let flags = cx.get_const_i32(0); + let size = cx.get_const_i64(0); + let data = cx.get_const_i64(0); + let aux_addr = cx.const_null(cx.type_ptr()); + vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr] + } +} - // Taken from the LLVM APITypes.h declaration: - //struct KernelArgsTy { +// Taken from the LLVM APITypes.h declaration: +struct KernelArgsTy { // uint32_t Version = 0; // Version of this struct for ABI compatibility. // uint32_t NumArgs = 0; // Number of arguments in each input pointer. // void **ArgBasePtrs = @@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { // void **ArgNames = nullptr; // Name of the data for debugging, possibly null. // void **ArgMappers = nullptr; // User-defined mappers, possibly null. // uint64_t Tripcount = - // 0; // Tripcount for the teams / distribute loop, 0 otherwise. - // struct { + // 0; // Tripcount for the teams / distribute loop, 0 otherwise. + // struct { // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. // uint64_t Unused : 62; @@ -131,12 +141,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { // // The number of threads (for x,y,z dimension). // uint32_t ThreadLimit[3] = {0, 0, 0}; // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested. 
- //}; - let kernel_elements = - vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; +} + +impl KernelArgsTy { + const OFFLOAD_VERSION: u64 = 3; + const FLAGS: u64 = 0; + const TRIPCOUNT: u64 = 0; + fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type { + let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let tarr = cx.type_array(ti32, 3); + + let kernel_elements = + vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; + + cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); + kernel_arguments_ty + } - cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); - kernel_arguments_ty + fn new<'ll>( + cx: &'ll SimpleCx<'_>, + num_args: u64, + memtransfer_types: &[&'ll Value], + geps: [&'ll Value; 3], + ) -> [(Align, &'ll Value); 13] { + let four = Align::from_bytes(4).expect("4 Byte alignment should work"); + let eight = Align::EIGHT; + let mut values = vec![]; + values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION))); + values.push((four, cx.get_const_i32(num_args))); + values.push((eight, geps[0])); + values.push((eight, geps[1])); + values.push((eight, geps[2])); + values.push((eight, memtransfer_types[0])); + // The next two are debug infos. 
FIXME(offload): set them + values.push((eight, cx.const_null(cx.type_ptr()))); + values.push((eight, cx.const_null(cx.type_ptr()))); + values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT))); + values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS))); + let ti32 = cx.type_i32(); + let ci32_0 = cx.get_const_i32(0); + values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); + values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); + values.push((four, cx.get_const_i32(0))); + values.try_into().expect("tgt_kernel_arguments construction failed") + } } fn gen_tgt_data_mappers<'ll>( @@ -242,19 +293,10 @@ fn gen_define_handling<'ll>( let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); llvm::set_alignment(llglobal, Align::ONE); llvm::set_section(llglobal, c".llvm.rodata.offloading"); - - // Not actively used yet, for calling real kernels let name = format!(".offloading.entry.kernel_{num}"); // See the __tgt_offload_entry documentation above. 
- let reserved = cx.get_const_i64(0); - let version = cx.get_const_i16(1); - let kind = cx.get_const_i16(1); - let flags = cx.get_const_i32(0); - let size = cx.get_const_i64(0); - let data = cx.get_const_i64(0); - let aux_addr = cx.const_null(cx.type_ptr()); - let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]; + let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); let initializer = crate::common::named_struct(offload_entry_ty, &elems); let c_name = CString::new(name).unwrap(); @@ -316,7 +358,7 @@ fn gen_call_handling<'ll>( let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false); - let tgt_kernel_decl = gen_tgt_kernel_global(&cx); + let tgt_kernel_decl = KernelArgsTy::new_decl(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); let main_fn = cx.get_function("main"); @@ -404,19 +446,19 @@ fn gen_call_handling<'ll>( a1: &'ll Value, a2: &'ll Value, a4: &'ll Value, - ) -> (&'ll Value, &'ll Value, &'ll Value) { + ) -> [&'ll Value; 3] { let i32_0 = cx.get_const_i32(0); let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); - (gep1, gep2, gep3) + [gep1, gep2, gep3] } fn generate_mapper_call<'a, 'll>( builder: &mut SBuilder<'a, 'll>, cx: &'ll SimpleCx<'ll>, - geps: (&'ll Value, &'ll Value, &'ll Value), + geps: [&'ll Value; 3], o_type: &'ll Value, fn_to_call: &'ll Value, fn_ty: &'ll Type, @@ -427,7 +469,7 @@ fn gen_call_handling<'ll>( let i64_max = cx.get_const_i64(u64::MAX); let num_args = cx.get_const_i32(num_args); let args = - vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr]; + vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr]; builder.call(fn_ty, fn_to_call, &args, None); } @@ -436,36 +478,20 @@ fn gen_call_handling<'ll>( let o = 
memtransfer_types[0]; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); + let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps); // Step 3) - let mut values = vec![]; - let offload_version = cx.get_const_i32(3); - values.push((4, offload_version)); - values.push((4, cx.get_const_i32(num_args))); - values.push((8, geps.0)); - values.push((8, geps.1)); - values.push((8, geps.2)); - values.push((8, memtransfer_types[0])); - // The next two are debug infos. FIXME(offload) set them - values.push((8, cx.const_null(cx.type_ptr()))); - values.push((8, cx.const_null(cx.type_ptr()))); - values.push((8, cx.get_const_i64(0))); - values.push((8, cx.get_const_i64(0))); - let ti32 = cx.type_i32(); - let ci32_0 = cx.get_const_i32(0); - values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); - values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); - values.push((4, cx.get_const_i32(0))); - + // Here we fill the KernelArgsTy, see the documentation above for (i, value) in values.iter().enumerate() { let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); - builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap()); + builder.store(value.1, ptr, value.0); } let args = vec![ s_ident_t, - // MAX == -1 - cx.get_const_i64(u64::MAX), + // FIXME(offload) give users a way to select which GPU to use. + cx.get_const_i64(u64::MAX), // MAX == -1. + // FIXME(offload): Don't hardcode the numbers of threads in the future. 
cx.get_const_i32(2097152), cx.get_const_i32(256), region_ids[0], @@ -480,19 +506,14 @@ fn gen_call_handling<'ll>( } // Step 4) - //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; - let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t); builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); drop(builder); + // FIXME(offload) The issue is that we right now add a call to the gpu version of the function, + // and then delete the call to the CPU version. In the future, we should use an intrinsic which + // directly resolves to a call to the GPU version. unsafe { llvm::LLVMDeleteFunction(called) }; - - // With this we generated the following begin and end mappers. We could easily generate the - // update mapper in an update. - // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) - // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null) - // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null) } From cdbbe9c85b6b79dc8639cb0827d0cfb5b4edc09d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Aug 2025 19:49:40 -0700 Subject: [PATCH 3/4] fix device code generation --- compiler/rustc_codegen_llvm/src/back/lto.rs | 2 +- compiler/rustc_codegen_llvm/src/back/write.rs | 81 +++++++++++++++++++ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 7 ++ .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 52 ++++++++++++ compiler/rustc_target/src/callconv/mod.rs | 1 + 5 files changed, 142 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 853d0295238e6..e99e0affbac35 100644 --- 
a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -585,7 +585,7 @@ pub(crate) fn run_pass_manager( write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; } - if enable_gpu && !thin { + if enable_gpu && !thin && !(cgcx.target_arch == "nvptx64" || cgcx.target_arch == "amdgpu") { let cx = SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx); diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 62998003ca114..771822d7a2e95 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -653,6 +653,87 @@ pub(crate) unsafe fn llvm_optimize( None }; + fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) { + unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) }; + //unsafe {llvm::LLVMDumpModule(m);} + //unsafe { + // // Get the old function type + // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn); + // dbg!(&old_fn_ty); + // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty); + // dbg!(&old_param_count); + + // // Get the old parameter types + // let mut old_param_types = Vec::with_capacity(old_param_count as usize); + // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr()); + // old_param_types.set_len(old_param_count as usize); + + // // Create the new parameter list, with ptr as the first argument + // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0); + // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1); + // new_param_types.push(ptr_ty); + // for old_param in old_param_types { + // new_param_types.push(old_param); + // } + // dbg!(&new_param_types); + + // // Create the new function type + // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty); + // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, 
new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0); + // dbg!(&new_fn_ty); + + // // Create the new function + // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap(); + // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap(); + // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name); + // let new_fn_cstr = CString::new(new_fn_name).unwrap(); + // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty); + // dbg!(&new_fn); + // let a0 = llvm::LLVMGetParam(new_fn, 0); + // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len()); + // dbg!(&new_fn); + + // // Move basic blocks + // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn); + // //dbg!(&bb); + // llvm::LLVMAppendExistingBasicBlock(new_fn, bb); + // //while !bb.is_null() { + // // let next = llvm::LLVMGetNextBasicBlock(bb); + // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb); + // // bb = next; + // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ... 
+ // let old_n = llvm::LLVMCountParams(old_fn); + // for i in 0..old_n { + // let old_arg = llvm::LLVMGetParam(old_fn, i); + // let new_arg = llvm::LLVMGetParam(new_fn, i + 1); + // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg); + // } + + // // Copy linkage and visibility + // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn)); + // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn)); + + // // Replace all uses of old_fn with new_fn (RAUW) + // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn); + + // // Optionally, remove the old function + // llvm::LLVMDeleteFunction(old_fn); + //} + } + + let consider_offload = config.offload.contains(&config::Offload::Enable); + if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") { + for num in 0..9 { + let name = format!("kernel_{num}"); + let c_name = CString::new(name).unwrap(); + if let Some(kernel) = + unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) } + { + handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel); + } + } + } + let mut llvm_profiler = cgcx .prof .llvm_recording_enabled() diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 5dead7f4e7ee5..0322be1180c3c 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1201,6 +1201,11 @@ unsafe extern "C" { // Operations on functions pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); + pub(crate) fn LLVMAddFunction<'a>( + Mod: &'a Module, + Name: *const c_char, + FunctionTy: &'a Type, + ) -> &'a Value; pub(crate) fn LLVMDeleteFunction(Fn: &Value); // Operations about llvm intrinsics @@ -1219,6 +1224,7 @@ unsafe extern "C" { // Operations on basic blocks pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value; + pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock); pub(crate) fn 
LLVMAppendBasicBlockInContext<'a>( C: &'a Context, Fn: &'a Value, @@ -1892,6 +1898,7 @@ unsafe extern "C" { ) -> &Attribute; // Operations on functions + pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value); pub(crate) fn LLVMRustGetOrInsertFunction<'a>( M: &'a Module, Name: *const c_char, diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index e699e4b9c13f8..ea88fb7f05b38 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -35,6 +35,8 @@ #include "llvm/Support/Signals.h" #include "llvm/Support/Timer.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include // for raw `write` in the bad-alloc handler @@ -170,6 +172,56 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) { llvm::PrintStatistics(OS); } +extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) { + llvm::Module *module = llvm::unwrap(M); + llvm::Function *oldFn = llvm::unwrap(Fn); + + if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") { + return; + } + + // 1. Create new function type with the leading extra %dyn_ptr arg which llvm + // offload requries. + llvm::LLVMContext &ctx = module->getContext(); + llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0); + std::vector argTypes; + argTypes.push_back(dynPtrType); + + for (auto &arg : oldFn->args()) { + argTypes.push_back(arg.getType()); + } + + llvm::FunctionType *newFnType = llvm::FunctionType::get( + oldFn->getReturnType(), argTypes, oldFn->isVarArg()); + + // use a temporary .offload appendix to avoid name clashes + llvm::Function *newFn = llvm::Function::Create( + newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module); + + // Map old arguments to new arguments. We skip the first dyn_ptr argument, + // since it can't be used directly by user code. 
+ llvm::ValueToValueMapTy vmap; + auto newArgIt = newFn->arg_begin(); + newArgIt->setName("dyn_ptr"); + ++newArgIt; // skip %dyn_ptr + for (auto &oldArg : oldFn->args()) { + vmap[&oldArg] = &*newArgIt++; + } + + llvm::SmallVector returns; + llvm::CloneFunctionInto(newFn, oldFn, vmap, + llvm::CloneFunctionChangeType::LocalChangesOnly, + returns); + newFn->setLinkage(oldFn->getLinkage()); + newFn->setVisibility(oldFn->getVisibility()); + + // Replace uses, delete old function, and reset name to the original one. + oldFn->replaceAllUsesWith(newFn); + auto name = oldFn->getName(); + oldFn->eraseFromParent(); + newFn->setName(name); +} + extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name, size_t NameLen) { return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen))); diff --git a/compiler/rustc_target/src/callconv/mod.rs b/compiler/rustc_target/src/callconv/mod.rs index 5f2a6f7ba38a1..a781401e4068f 100644 --- a/compiler/rustc_target/src/callconv/mod.rs +++ b/compiler/rustc_target/src/callconv/mod.rs @@ -577,6 +577,7 @@ impl RiscvInterruptKind { /// /// The signature represented by this type may not match the MIR function signature. /// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`. +/// The std::offload module also adds an additional dyn_ptr argument to the GpuKernel ABI. /// While this difference is rarely relevant, it should still be kept in mind. 
/// /// I will do my best to describe this structure, but these From dd5af930763d41b1e01a72841452bb70c9d90dd3 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Sep 2025 00:36:08 -0700 Subject: [PATCH 4/4] upgrade offload dyn_ptr handling from C++ to mostly safe Rust --- compiler/rustc_codegen_llvm/src/back/write.rs | 133 ++++++++---------- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 +- compiler/rustc_codegen_llvm/src/type_.rs | 5 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 35 +---- 4 files changed, 70 insertions(+), 106 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 771822d7a2e95..53f35dac1c54c 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -44,7 +44,7 @@ use crate::errors::{ use crate::llvm::diagnostic::OptimizationDiagnosticKind::*; use crate::llvm::{self, DiagnosticInfo}; use crate::type_::Type; -use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util}; pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError { match llvm::last_error() { @@ -653,83 +653,72 @@ pub(crate) unsafe fn llvm_optimize( None }; - fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) { - unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) }; - //unsafe {llvm::LLVMDumpModule(m);} - //unsafe { - // // Get the old function type - // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn); - // dbg!(&old_fn_ty); - // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty); - // dbg!(&old_param_count); - - // // Get the old parameter types - // let mut old_param_types = Vec::with_capacity(old_param_count as usize); - // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr()); - // old_param_types.set_len(old_param_count as usize); - - // // Create the new parameter list, with ptr 
as the first argument - // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0); - // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1); - // new_param_types.push(ptr_ty); - // for old_param in old_param_types { - // new_param_types.push(old_param); - // } - // dbg!(&new_param_types); - - // // Create the new function type - // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty); - // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0); - // dbg!(&new_fn_ty); - - // // Create the new function - // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap(); - // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap(); - // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name); - // let new_fn_cstr = CString::new(new_fn_name).unwrap(); - // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty); - // dbg!(&new_fn); - // let a0 = llvm::LLVMGetParam(new_fn, 0); - // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len()); - // dbg!(&new_fn); - - // // Move basic blocks - // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn); - // //dbg!(&bb); - // llvm::LLVMAppendExistingBasicBlock(new_fn, bb); - // //while !bb.is_null() { - // // let next = llvm::LLVMGetNextBasicBlock(bb); - // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb); - // // bb = next; - // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ... 
- // let old_n = llvm::LLVMCountParams(old_fn); - // for i in 0..old_n { - // let old_arg = llvm::LLVMGetParam(old_fn, i); - // let new_arg = llvm::LLVMGetParam(new_fn, i + 1); - // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg); - // } - - // // Copy linkage and visibility - // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn)); - // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn)); - - // // Replace all uses of old_fn with new_fn (RAUW) - // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn); - - // // Optionally, remove the old function - // llvm::LLVMDeleteFunction(old_fn); - //} + fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) { + { + let old_fn_ty = cx.get_type_of_global(old_fn); + let old_param_types = cx.func_params_types(old_fn_ty); + let old_param_count = old_param_types.len(); + if old_param_count == 0 { + return; + } + + let first_param = llvm::get_param(old_fn, 0); + let c_name = llvm::get_value_name(first_param); + let first_arg_name = str::from_utf8(&c_name).unwrap(); + // We might call llvm_optimize (and thus this code) multiple times on the same IR, + // but we shouldn't add this helper ptr multiple times. + if first_arg_name == "dyn_ptr" { + return; + } + + // Create the new parameter list, with ptr as the first argument + let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1); + new_param_types.push(cx.type_ptr()); + for old_param in old_param_types { + new_param_types.push(old_param); + } + + // Create the new function type + let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) }; + let new_fn_ty = cx.type_func(&new_param_types, ret_ty); + + // Create the new function, with a temporary .offload name to avoid a name collision. 
+ let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap(); + let new_fn_name = format!("{}.offload", &old_fn_name); + let new_fn = cx.add_func(&new_fn_name, new_fn_ty); + let a0 = llvm::get_param(new_fn, 0); + llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes()); + + // Here we map the old arguments to the new arguments, with an offset of 1 to make sure + // that we don't use the newly added `%dyn_ptr`. + unsafe { + llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn); + } + + llvm::set_linkage(new_fn, llvm::get_linkage(old_fn)); + llvm::set_visibility(new_fn, llvm::get_visibility(old_fn)); + + // Replace all uses of old_fn with new_fn (RAUW) + unsafe { + llvm::LLVMReplaceAllUsesWith(old_fn, new_fn); + } + let name = llvm::get_value_name(old_fn); + unsafe { + llvm::LLVMDeleteFunction(old_fn); + } + // Now we can re-use the old name, without name collision. + llvm::set_value_name(new_fn, &name); + } } let consider_offload = config.offload.contains(&config::Offload::Enable); if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") { + let cx = + SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); for num in 0..9 { let name = format!("kernel_{num}"); - let c_name = CString::new(name).unwrap(); - if let Some(kernel) = - unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) } - { - handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel); + if let Some(kernel) = cx.get_function(&name) { + handle_offload(&cx, kernel); } } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 0322be1180c3c..18b6640556a83 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1224,7 +1224,6 @@ unsafe extern "C" { // Operations on basic blocks pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value; - pub(crate) fn 
LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock); pub(crate) fn LLVMAppendBasicBlockInContext<'a>( C: &'a Context, Fn: &'a Value, @@ -1898,7 +1897,7 @@ unsafe extern "C" { ) -> &Attribute; // Operations on functions - pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value); + pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, OldFn: &'a Value, NewFn: &'a Value); pub(crate) fn LLVMRustGetOrInsertFunction<'a>( M: &'a Module, Name: *const c_char, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index f02d16baf94e7..3756718a28209 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -68,6 +68,11 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { unsafe { llvm::LLVMVectorType(ty, len as c_uint) } } + pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value { + let name = SmallCStr::new(name); + unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) } + } + pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> { unsafe { let n_args = llvm::LLVMCountParamTypes(ty) as usize; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index ea88fb7f05b38..d37f1ff17b5f4 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -172,31 +172,10 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) { llvm::PrintStatistics(OS); } -extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) { +extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn, LLVMValueRef NewFn) { llvm::Module *module = llvm::unwrap(M); - llvm::Function *oldFn = llvm::unwrap(Fn); - - if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") { - return; - } - - // 1. Create new function type with the leading extra %dyn_ptr arg which llvm - // offload requries. 
- llvm::LLVMContext &ctx = module->getContext(); - llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0); - std::vector argTypes; - argTypes.push_back(dynPtrType); - - for (auto &arg : oldFn->args()) { - argTypes.push_back(arg.getType()); - } - - llvm::FunctionType *newFnType = llvm::FunctionType::get( - oldFn->getReturnType(), argTypes, oldFn->isVarArg()); - - // use a temporary .offload appendix to avoid name clashes - llvm::Function *newFn = llvm::Function::Create( - newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module); + llvm::Function *oldFn = llvm::unwrap(OldFn); + llvm::Function *newFn = llvm::unwrap(NewFn); // Map old arguments to new arguments. We skip the first dyn_ptr argument, // since it can't be used directly by user code. @@ -212,14 +191,6 @@ extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) { llvm::CloneFunctionInto(newFn, oldFn, vmap, llvm::CloneFunctionChangeType::LocalChangesOnly, returns); - newFn->setLinkage(oldFn->getLinkage()); - newFn->setVisibility(oldFn->getVisibility()); - - // Replace uses, delete old function, and reset name to the original one. - oldFn->replaceAllUsesWith(newFn); - auto name = oldFn->getName(); - oldFn->eraseFromParent(); - newFn->setName(name); } extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,