Skip to content

Commit 4a35d3a

Browse files
committed
host might be ready?
1 parent e8d1d8a commit 4a35d3a

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,24 @@ pub(crate) fn handle_gpu_code<'ll>(
1818
// The offload memory transfer type for each kernel
1919
let mut o_types = vec![];
2020
let mut kernels = vec![];
21+
let mut region_ids = vec![];
2122
let offload_entry_ty = add_tgt_offload_entry(&cx);
2223
for num in 0..9 {
2324
let kernel = cx.get_function(&format!("kernel_{num}"));
2425
if let Some(kernel) = kernel {
25-
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
26+
let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
27+
o_types.push(o);
28+
region_ids.push(k);
2629
kernels.push(kernel);
2730
}
2831
}
29-
generate_launcher(&cx);
30-
gen_call_handling(&cx, &kernels, &o_types);
32+
gen_call_handling(&cx, &kernels, &o_types, &region_ids);
3133
crate::builder::gpu_wrapper::gen_image_wrapper_module(&cgcx);
3234
}
3335

3436
// ; Function Attrs: nounwind
3537
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
36-
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
38+
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
3739
let tptr = cx.type_ptr();
3840
let ti64 = cx.type_i64();
3941
let ti32 = cx.type_i32();
@@ -43,7 +45,7 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
4345
let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
4446
let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
4547
attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
46-
tgt_decl
48+
(tgt_decl, tgt_fn_ty)
4749
}
4850

4951
// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
@@ -204,7 +206,7 @@ fn gen_define_handling<'ll>(
204206
kernel: &'ll llvm::Value,
205207
offload_entry_ty: &'ll llvm::Type,
206208
num: i64,
207-
) -> &'ll llvm::Value {
209+
) -> (&'ll llvm::Value, &'ll llvm::Value) {
208210
let types = cx.func_params_types(cx.get_type_of_global(kernel));
209211
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
210212
// reference) types.
@@ -262,7 +264,7 @@ fn gen_define_handling<'ll>(
262264
llvm::set_alignment(llglobal, Align::ONE);
263265
let c_section_name = CString::new(".omp_offloading_entries").unwrap();
264266
llvm::set_section(llglobal, &c_section_name);
265-
o_types
267+
(o_types, region_id)
266268
}
267269

268270
pub(crate) fn declare_offload_fn<'ll>(
@@ -304,7 +306,9 @@ fn gen_call_handling<'ll>(
304306
cx: &'ll SimpleCx<'_>,
305307
_kernels: &[&'ll llvm::Value],
306308
o_types: &[&'ll llvm::Value],
309+
region_ids: &[&'ll llvm::Value],
307310
) {
311+
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
308312
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
309313
let tptr = cx.type_ptr();
310314
let ti32 = cx.type_i32();
@@ -491,8 +495,26 @@ fn gen_call_handling<'ll>(
491495
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
492496
}
493497

498+
let args = vec![
499+
s_ident_t,
500+
// u64::MAX has the same bit pattern as -1, so this is emitted as `i64 -1` in the IR
501+
cx.get_const_i64(u64::MAX),
502+
cx.get_const_i32(2097152),
503+
cx.get_const_i32(256),
504+
region_ids[0],
505+
a5,
506+
];
507+
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
508+
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
509+
unsafe {
510+
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
511+
dbg!(&next);
512+
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
513+
llvm::LLVMInstructionEraseFromParent(next);
514+
}
515+
494516
// Step 4)
495-
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
517+
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
496518

497519
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
498520
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,8 @@ unsafe extern "C" {
12521252
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
12531253
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
12541254
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
1255+
pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>;
1256+
pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value);
12551257

12561258
// Operations on call sites
12571259
pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint);

tests/codegen-llvm/gpu_offload/gpu_host.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,37 @@ fn main() {
6060
// CHECK-NEXT: %7 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
6161
// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0
6262
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null)
63-
// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x)
64-
// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
65-
// CHECK-NEXT: %10 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
66-
// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0
67-
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null)
63+
// CHECK-NEXT: %9 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0
64+
// CHECK-NEXT: store i32 3, ptr %9, align 4
65+
// CHECK-NEXT: %10 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1
66+
// CHECK-NEXT: store i32 3, ptr %10, align 4
67+
// CHECK-NEXT: %11 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2
68+
// CHECK-NEXT: store ptr %6, ptr %11, align 8
69+
// CHECK-NEXT: %12 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3
70+
// CHECK-NEXT: store ptr %7, ptr %12, align 8
71+
// CHECK-NEXT: %13 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4
72+
// CHECK-NEXT: store ptr %8, ptr %13, align 8
73+
// CHECK-NEXT: %14 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5
74+
// CHECK-NEXT: store ptr @.offload_maptypes.1, ptr %14, align 8
75+
// CHECK-NEXT: %15 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6
76+
// CHECK-NEXT: store ptr null, ptr %15, align 8
77+
// CHECK-NEXT: %16 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7
78+
// CHECK-NEXT: store ptr null, ptr %16, align 8
79+
// CHECK-NEXT: %17 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8
80+
// CHECK-NEXT: store i64 0, ptr %17, align 8
81+
// CHECK-NEXT: %18 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9
82+
// CHECK-NEXT: store i64 0, ptr %18, align 8
83+
// CHECK-NEXT: %19 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10
84+
// CHECK-NEXT: store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %19, align 8
85+
// CHECK-NEXT: %20 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11
86+
// CHECK-NEXT: store [3 x i32] [i32 256, i32 0, i32 0], ptr %20, align 8
87+
// CHECK-NEXT: %21 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12
88+
// CHECK-NEXT: store i32 0, ptr %21, align 4
89+
// CHECK-NEXT: %22 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
90+
// CHECK-NEXT: %23 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
91+
// CHECK-NEXT: %24 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
92+
// CHECK-NEXT: %25 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0
93+
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %23, ptr %24, ptr %25, ptr @.offload_maptypes.1, ptr null, ptr null)
6894
// CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc)
6995
// CHECK: store ptr %x, ptr %0, align 8
7096
// CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0)

0 commit comments

Comments
 (0)