Skip to content

Commit 5958ebe

Browse files
committed
add various wrappers for gpu code generation
1 parent 6340164 commit 5958ebe

File tree

5 files changed

+140
-2
lines changed

5 files changed

+140
-2
lines changed

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ops::Deref;
33
use std::{iter, ptr};
44

55
pub(crate) mod autodiff;
6+
pub(crate) mod gpu_offload;
67

78
use libc::{c_char, c_uint, size_t};
89
use rustc_abi as abi;
@@ -117,6 +118,74 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
117118
}
118119
bx
119120
}
121+
122+
// The generic builder has less functionality and thus (unlike the other alloca) we can not
123+
// easily jump to the beginning of the function to place our allocas there. We trust the user
124+
// to manually do that. FIXME(offload): improve the genericCx and add more llvm wrappers to
125+
// handle this.
126+
pub(crate) fn direct_alloca(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
127+
let val = unsafe {
128+
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
129+
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
130+
// Cast to default addrspace if necessary
131+
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
132+
};
133+
if name != "" {
134+
let name = std::ffi::CString::new(name).unwrap();
135+
llvm::set_value_name(val, &name.as_bytes());
136+
}
137+
val
138+
}
139+
140+
pub(crate) fn inbounds_gep(
141+
&mut self,
142+
ty: &'ll Type,
143+
ptr: &'ll Value,
144+
indices: &[&'ll Value],
145+
) -> &'ll Value {
146+
unsafe {
147+
llvm::LLVMBuildGEPWithNoWrapFlags(
148+
self.llbuilder,
149+
ty,
150+
ptr,
151+
indices.as_ptr(),
152+
indices.len() as c_uint,
153+
UNNAMED,
154+
GEPNoWrapFlags::InBounds,
155+
)
156+
}
157+
}
158+
159+
pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value {
160+
debug!("Store {:?} -> {:?}", val, ptr);
161+
assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
162+
unsafe {
163+
let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
164+
llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
165+
store
166+
}
167+
}
168+
169+
pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
170+
unsafe {
171+
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
172+
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
173+
load
174+
}
175+
}
176+
177+
fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
178+
unsafe {
179+
llvm::LLVMRustBuildMemSet(
180+
self.llbuilder,
181+
ptr,
182+
align.bytes() as c_uint,
183+
fill_byte,
184+
size,
185+
false,
186+
);
187+
}
188+
}
120189
}
121190

122191
/// Empty string, to be used where LLVM expects an instruction name, indicating

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ pub(crate) unsafe fn create_module<'ll>(
211211

212212
// Ensure the data-layout values hardcoded remain the defaults.
213213
{
214-
let tm = crate::back::write::create_informational_target_machine(tcx.sess, false);
214+
let tm = crate::back::write::create_informational_target_machine(sess, false);
215215
unsafe {
216216
llvm::LLVMRustSetDataLayoutFromTargetMachine(llmod, tm.raw());
217217
}
@@ -680,6 +680,22 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
680680
unsafe { llvm::LLVMConstInt(ty, val, llvm::False) }
681681
}
682682

683+
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
684+
self.get_const_int(self.type_i64(), n)
685+
}
686+
687+
pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value {
688+
self.get_const_int(self.type_i32(), n)
689+
}
690+
691+
pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value {
692+
self.get_const_int(self.type_i16(), n)
693+
}
694+
695+
pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value {
696+
self.get_const_int(self.type_i8(), n)
697+
}
698+
683699
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
684700
let name = SmallCStr::new(name);
685701
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use libc::{c_char, c_uint};
44

55
use super::MetadataKindId;
66
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
7-
use crate::llvm::Bool;
7+
use crate::llvm::{Bool, Builder};
88

99
#[link(name = "llvm-wrapper", kind = "static")]
1010
unsafe extern "C" {
@@ -31,6 +31,14 @@ unsafe extern "C" {
3131
index: c_uint,
3232
kind: AttributeKind,
3333
);
34+
pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
35+
pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
36+
pub(crate) fn LLVMRustGetFunctionCall(
37+
F: &Value,
38+
name: *const c_char,
39+
NameLen: libc::size_t,
40+
) -> Option<&Value>;
41+
3442
}
3543

3644
unsafe extern "C" {

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,11 @@ unsafe extern "C" {
11381138
Count: c_uint,
11391139
Packed: Bool,
11401140
) -> &'a Value;
1141+
pub(crate) fn LLVMConstNamedStruct<'a>(
1142+
StructTy: &'a Type,
1143+
ConstantVals: *const &'a Value,
1144+
Count: c_uint,
1145+
) -> &'a Value;
11411146
pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value;
11421147

11431148
// Constant expressions
@@ -1217,6 +1222,8 @@ unsafe extern "C" {
12171222
) -> &'a BasicBlock;
12181223

12191224
// Operations on instructions
1225+
pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock;
1226+
pub(crate) fn LLVMGetCalledValue(CallInst: &Value) -> Option<&Value>;
12201227
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
12211228
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
12221229
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
@@ -2556,6 +2563,7 @@ unsafe extern "C" {
25562563

25572564
pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine);
25582565

2566+
pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value);
25592567
pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock);
25602568

25612569
pub(crate) fn LLVMRustSetModulePICLevel(M: &Module);

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,12 +1591,49 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst,
15911591
MaybeAlign(DstAlign), IsVolatile));
15921592
}
15931593

1594+
extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B,
1595+
LLVMValueRef Fn) {
1596+
Function *F = unwrap<Function>(Fn);
1597+
unwrap(B)->SetInsertPointPastAllocas(F);
1598+
}
15941599
extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
15951600
LLVMBasicBlockRef BB) {
15961601
auto Point = unwrap(BB)->getFirstInsertionPt();
15971602
unwrap(B)->SetInsertPoint(unwrap(BB), Point);
15981603
}
15991604

1605+
extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) {
1606+
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
1607+
unwrap(B)->SetInsertPoint(I);
1608+
}
1609+
}
1610+
1611+
extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) {
1612+
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
1613+
auto J = I->getNextNonDebugInstruction();
1614+
unwrap(B)->SetInsertPoint(J);
1615+
}
1616+
}
1617+
1618+
extern "C" LLVMValueRef
1619+
LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) {
1620+
auto targetName = StringRef(Name, NameLen);
1621+
Function *F = unwrap<Function>(Fn);
1622+
for (auto &BB : *F) {
1623+
for (auto &I : BB) {
1624+
if (auto *callInst = llvm::dyn_cast<llvm::CallBase>(&I)) {
1625+
const llvm::Function *calledFunc = callInst->getCalledFunction();
1626+
if (calledFunc && calledFunc->getName() == targetName) {
1627+
// Found a call to the target function
1628+
return wrap(callInst);
1629+
}
1630+
}
1631+
}
1632+
}
1633+
1634+
return nullptr;
1635+
}
1636+
16001637
extern "C" bool LLVMRustConstIntGetZExtValue(LLVMValueRef CV, uint64_t *value) {
16011638
auto C = unwrap<llvm::ConstantInt>(CV);
16021639
if (C->getBitWidth() > 64)

0 commit comments

Comments
 (0)