Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ pub(crate) fn run_pass_manager(
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

if enable_gpu && !thin {
if enable_gpu && !thin && !(cgcx.target_arch == "nvptx64" || cgcx.target_arch == "amdgpu") {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
Expand Down
72 changes: 71 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use crate::errors::{
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
use crate::llvm::{self, DiagnosticInfo};
use crate::type_::Type;
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};

pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError {
match llvm::last_error() {
Expand Down Expand Up @@ -653,6 +653,76 @@ pub(crate) unsafe fn llvm_optimize(
None
};

fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
{
let old_fn_ty = cx.get_type_of_global(old_fn);
let old_param_types = cx.func_params_types(old_fn_ty);
let old_param_count = old_param_types.len();
if old_param_count == 0 {
return;
}

let first_param = llvm::get_param(old_fn, 0);
let c_name = llvm::get_value_name(first_param);
let first_arg_name = str::from_utf8(&c_name).unwrap();
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
// but we shouldn't add this helper ptr multiple times.
if first_arg_name == "dyn_ptr" {
return;
}

// Create the new parameter list, with ptr as the first argument
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
new_param_types.push(cx.type_ptr());
for old_param in old_param_types {
new_param_types.push(old_param);
}

// Create the new function type
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);

// Create the new function, with a temporary .offload name to avoid a name collision.
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
let new_fn_name = format!("{}.offload", &old_fn_name);
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
let a0 = llvm::get_param(new_fn, 0);
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());

// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
// that we don't use the newly added `%dyn_ptr`.
unsafe {
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
}

llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));

// Replace all uses of old_fn with new_fn (RAUW)
unsafe {
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
}
let name = llvm::get_value_name(old_fn);
unsafe {
llvm::LLVMDeleteFunction(old_fn);
}
// Now we can re-use the old name, without name collision.
llvm::set_value_name(new_fn, &name);
}
}

let consider_offload = config.offload.contains(&config::Offload::Enable);
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should probably be combined to a target_is_gpu, similar to target_is_like_darwin and target_is_like_aix

let cx =
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
for num in 0..9 {
let name = format!("kernel_{num}");
if let Some(kernel) = cx.get_function(&name) {
handle_offload(&cx, kernel);
}
}
}

let mut llvm_profiler = cgcx
.prof
.llvm_recording_enabled()
Expand Down
Loading
Loading