Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {

fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
unsafe {
llvm::LLVMRustBuildMemSet(
self.llbuilder,
ptr,
align.bytes() as c_uint,
fill_byte,
size,
false,
);
llvm::LLVMBuildMemSet(self.llbuilder, ptr, align.bytes() as c_uint, fill_byte, size);
}
}
}
Expand Down Expand Up @@ -1103,17 +1096,22 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
llvm::LLVMBuildMemCpy(
self.llbuilder,
dst,
dst_align.bytes() as c_uint,
src,
src_align.bytes() as c_uint,
size,
is_volatile,
)
};

if is_volatile {
unsafe {
llvm::LLVMSetVolatile(memcpy, llvm::TRUE);
}
}

// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
Expand All @@ -1136,16 +1134,21 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
llvm::LLVMRustBuildMemMove(
let memmove = unsafe {
llvm::LLVMBuildMemMove(
self.llbuilder,
dst,
dst_align.bytes() as c_uint,
src,
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};

if is_volatile {
unsafe {
llvm::LLVMSetVolatile(memmove, llvm::TRUE);
}
}
}

Expand All @@ -1159,15 +1162,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported");
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
llvm::LLVMRustBuildMemSet(
self.llbuilder,
ptr,
align.bytes() as c_uint,
fill_byte,
size,
is_volatile,
);
let memset = unsafe {
llvm::LLVMBuildMemSet(self.llbuilder, ptr, align.bytes() as c_uint, fill_byte, size)
};

if is_volatile {
unsafe {
llvm::LLVMSetVolatile(memset, llvm::TRUE);
}
}
}

Expand Down
52 changes: 25 additions & 27 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,31 @@ unsafe extern "C" {
NumBundles: c_uint,
Name: *const c_char,
) -> &'a Value;

// Memory operations
pub(crate) fn LLVMBuildMemCpy<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Src: &'a Value,
SrcAlign: c_uint,
Size: &'a Value,
) -> &'a Value;
pub(crate) fn LLVMBuildMemMove<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Src: &'a Value,
SrcAlign: c_uint,
Size: &'a Value,
) -> &'a Value;
pub(crate) fn LLVMBuildMemSet<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Val: &'a Value,
Size: &'a Value,
) -> &'a Value;
}

// FFI bindings for `DIBuilder` functions in the LLVM-C API.
Expand Down Expand Up @@ -2016,33 +2041,6 @@ unsafe extern "C" {
pub(crate) fn LLVMRustSetAllowReassoc(Instr: &Value);

// Miscellaneous instructions
pub(crate) fn LLVMRustBuildMemCpy<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Src: &'a Value,
SrcAlign: c_uint,
Size: &'a Value,
IsVolatile: bool,
) -> &'a Value;
pub(crate) fn LLVMRustBuildMemMove<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Src: &'a Value,
SrcAlign: c_uint,
Size: &'a Value,
IsVolatile: bool,
) -> &'a Value;
pub(crate) fn LLVMRustBuildMemSet<'a>(
B: &Builder<'a>,
Dst: &'a Value,
DstAlign: c_uint,
Val: &'a Value,
Size: &'a Value,
IsVolatile: bool,
) -> &'a Value;

pub(crate) fn LLVMRustBuildVectorReduceFAdd<'a>(
B: &Builder<'a>,
Acc: &'a Value,
Expand Down
27 changes: 0 additions & 27 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1378,33 +1378,6 @@ LLVMRustUnpackSMDiagnostic(LLVMSMDiagnosticRef DRef, RustStringRef MessageOut,
return true;
}

extern "C" LLVMValueRef LLVMRustBuildMemCpy(LLVMBuilderRef B, LLVMValueRef Dst,
unsigned DstAlign, LLVMValueRef Src,
unsigned SrcAlign,
LLVMValueRef Size,
bool IsVolatile) {
return wrap(unwrap(B)->CreateMemCpy(unwrap(Dst), MaybeAlign(DstAlign),
unwrap(Src), MaybeAlign(SrcAlign),
unwrap(Size), IsVolatile));
}

extern "C" LLVMValueRef
LLVMRustBuildMemMove(LLVMBuilderRef B, LLVMValueRef Dst, unsigned DstAlign,
LLVMValueRef Src, unsigned SrcAlign, LLVMValueRef Size,
bool IsVolatile) {
return wrap(unwrap(B)->CreateMemMove(unwrap(Dst), MaybeAlign(DstAlign),
unwrap(Src), MaybeAlign(SrcAlign),
unwrap(Size), IsVolatile));
}

extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst,
unsigned DstAlign, LLVMValueRef Val,
LLVMValueRef Size,
bool IsVolatile) {
return wrap(unwrap(B)->CreateMemSet(unwrap(Dst), unwrap(Val), unwrap(Size),
MaybeAlign(DstAlign), IsVolatile));
}

extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B,
LLVMValueRef Fn) {
Function *F = unwrap<Function>(Fn);
Expand Down
Loading