Skip to content

Commit b9334e6

Browse files
committed
addressing more feedback
1 parent f078df5 commit b9334e6

File tree

1 file changed

+59
-60
lines changed

1 file changed

+59
-60
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 59 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
88
use crate::builder::SBuilder;
99
use crate::common::AsCCharPtr;
1010
use crate::llvm::AttributePlace::Function;
11-
use crate::llvm::{self, Linkage};
11+
use crate::llvm::{self, Linkage, Type, Value};
1212
use crate::{LlvmCodegenBackend, SimpleCx, attributes};
1313

1414
pub(crate) fn handle_gpu_code<'ll>(
1515
_cgcx: &CodegenContext<LlvmCodegenBackend>,
1616
cx: &'ll SimpleCx<'_>,
1717
) {
18+
// The offload memory transfer type for each kernel
1819
let mut o_types = vec![];
1920
let mut kernels = vec![];
2021
let offload_entry_ty = add_tgt_offload_entry(&cx);
@@ -43,7 +44,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
4344

4445
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
4546
let struct_ident_ty = cx.type_named_struct("struct.ident_t");
46-
let struct_elems: Vec<&llvm::Value> = vec![
47+
let struct_elems = vec![
4748
cx.get_const_i32(0),
4849
cx.get_const_i32(2),
4950
cx.get_const_i32(0),
@@ -163,7 +164,7 @@ pub(crate) fn add_unnamed_global<'ll>(
163164
l: Linkage,
164165
) -> &'ll llvm::Value {
165166
let llglobal = add_global(cx, name, initializer, l);
166-
unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) };
167+
llvm::SetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
167168
llglobal
168169
}
169170

@@ -220,24 +221,20 @@ fn gen_define_handling<'ll>(
220221
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
221222
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
222223
llvm::set_alignment(llglobal, Align::ONE);
223-
let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
224-
llvm::set_section(llglobal, &c_section_name);
224+
llvm::set_section(llglobal, c".llvm.rodata.offloading");
225225

226226
// Not actively used yet, for calling real kernels
227227
let name = format!(".offloading.entry.kernel_{num}");
228-
let ci64_0 = cx.get_const_i64(0);
229-
let ci16_1 = cx.get_const_i16(1);
230-
let elems: Vec<&llvm::Value> = vec![
231-
ci64_0,
232-
ci16_1,
233-
ci16_1,
234-
cx.get_const_i32(0),
235-
region_id,
236-
llglobal,
237-
ci64_0,
238-
ci64_0,
239-
cx.const_null(cx.type_ptr()),
240-
];
228+
229+
// See the __tgt_offload_entry documentation above.
230+
let reserved = cx.get_const_i64(0);
231+
let version = cx.get_const_i16(1);
232+
let kind = cx.get_const_i16(1);
233+
let flags = cx.get_const_i32(0);
234+
let size = cx.get_const_i64(0);
235+
let data = cx.get_const_i64(0);
236+
let aux_addr = cx.const_null(cx.type_ptr());
237+
let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
241238

242239
let initializer = crate::common::named_struct(offload_entry_ty, &elems);
243240
let c_name = CString::new(name).unwrap();
@@ -353,12 +350,7 @@ fn gen_call_handling<'ll>(
353350

354351
// Step 1)
355352
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
356-
builder.memset(
357-
tgt_bin_desc_alloca,
358-
cx.get_const_i8(0),
359-
cx.get_const_i64(32),
360-
Align::from_bytes(8).unwrap(),
361-
);
353+
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
362354

363355
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
364356
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
@@ -384,26 +376,48 @@ fn gen_call_handling<'ll>(
384376
builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
385377
}
386378

387-
// Step 2)
388-
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
389-
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
390-
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
379+
// For now we have a very simplistic indexing scheme into our
380+
// offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
381+
fn get_geps<'a, 'll>(
382+
builder: &mut SBuilder<'a, 'll>,
383+
cx: &'ll SimpleCx<'ll>,
384+
ty: &'ll Type,
385+
ty2: &'ll Type,
386+
a1: &'ll Value,
387+
a2: &'ll Value,
388+
a4: &'ll Value,
389+
) -> (&'ll Value, &'ll Value, &'ll Value) {
390+
let i32_0 = cx.get_const_i32(0);
391+
392+
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
393+
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
394+
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
395+
(gep1, gep2, gep3)
396+
}
391397

392-
let nullptr = cx.const_null(cx.type_ptr());
393-
let o_type = o_types[0];
398+
fn generate_mapper_call<'a, 'll>(
399+
builder: &mut SBuilder<'a, 'll>,
400+
cx: &'ll SimpleCx<'ll>,
401+
geps: (&'ll Value, &'ll Value, &'ll Value),
402+
o_type: &'ll Value,
403+
fn_to_call: &'ll Value,
404+
fn_ty: &'ll Type,
405+
num_args: u64,
406+
s_ident_t: &'ll Value,
407+
) {
408+
let nullptr = cx.const_null(cx.type_ptr());
409+
let i64_max = cx.get_const_i64(u64::MAX);
410+
let num_args = cx.get_const_i32(num_args);
411+
let args =
412+
vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
413+
builder.call(fn_ty, fn_to_call, &args, None);
414+
}
415+
416+
// Step 2)
394417
let s_ident_t = generate_at_one(&cx);
395-
let args = vec![
396-
s_ident_t,
397-
cx.get_const_i64(u64::MAX),
398-
cx.get_const_i32(num_args),
399-
gep1,
400-
gep2,
401-
gep3,
402-
o_type,
403-
nullptr,
404-
nullptr,
405-
];
406-
builder.call(fn_ty, begin_mapper_decl, &args, None);
418+
let o = o_types[0];
419+
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
420+
generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
407421

408422
// Step 3)
409423
// Here we will add code for the actual kernel launches in a follow-up PR.
@@ -412,24 +426,9 @@ fn gen_call_handling<'ll>(
412426
// Step 4)
413427
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
414428

415-
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
416-
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
417-
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
418-
419-
let nullptr = cx.const_null(cx.type_ptr());
420-
let o_type = o_types[0];
421-
let args = vec![
422-
s_ident_t,
423-
cx.get_const_i64(u64::MAX),
424-
cx.get_const_i32(num_args),
425-
gep1,
426-
gep2,
427-
gep3,
428-
o_type,
429-
nullptr,
430-
nullptr,
431-
];
432-
builder.call(fn_ty, end_mapper_decl, &args, None);
429+
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
430+
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
431+
433432
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
434433

435434
// With this we generated the following begin and end mappers. We could easily generate the

0 commit comments

Comments
 (0)