Skip to content

Commit 9a8af0c

Browse files
committed
address review, refactor ffi logic
1 parent c6765bf commit 9a8af0c

15 files changed

+347
-214
lines changed

src/shims/native_lib/ffi.rs

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,72 +3,65 @@ use libffi::middle::{Arg as ArgPtr, Cif, Type as FfiType};
33

44
/// Perform the actual FFI call.
55
///
6-
/// SAFETY: The `FfiArg`s passed must have been correctly instantiated (i.e. their
7-
/// type layout must match the data they point to), and the safety invariants of
8-
/// the foreign function being called must be upheld (if any).
9-
pub unsafe fn call<'a, R: libffi::high::CType>(fun: CodePtr, args: Vec<FfiArg<'a>>) -> R {
6+
/// SAFETY: The safety invariants of the foreign function being called must be
7+
/// upheld (if any).
8+
pub unsafe fn call<R: libffi::high::CType>(fun: CodePtr, args: &[OwnedArg]) -> R {
109
let mut arg_tys = vec![];
1110
let mut arg_ptrs = vec![];
1211
for arg in args {
13-
arg_tys.push(arg.ty);
14-
arg_ptrs.push(arg.ptr)
12+
arg_tys.push(arg.ty());
13+
arg_ptrs.push(arg.ptr())
1514
}
1615
let cif = Cif::new(arg_tys, R::reify().into_middle());
16+
// SAFETY: Caller upholds that the function is safe to call, and since we
17+
// were passed a slice reference we know the `OwnedArg`s won't have been
18+
// by this point.
1719
unsafe { cif.call(fun, &arg_ptrs) }
1820
}
1921

20-
/// A wrapper type for `libffi::middle::Type` which also holds a pointer to the data.
21-
pub struct FfiArg<'a> {
22-
/// The type layout information for the pointed-to data.
23-
ty: FfiType,
24-
/// A pointer to the data described in `ty`.
25-
ptr: ArgPtr,
26-
/// Lifetime of the actual pointed-to data.
27-
_p: std::marker::PhantomData<&'a [u8]>,
28-
}
29-
30-
impl<'a> FfiArg<'a> {
31-
fn new(ty: FfiType, ptr: ArgPtr) -> Self {
32-
Self { ty, ptr, _p: std::marker::PhantomData }
33-
}
34-
}
35-
36-
/// An owning form of `FfiArg`.
37-
/// We introduce this enum instead of just calling `Arg::new` and storing a list of
38-
/// `libffi::middle::Arg` directly, because the `libffi::middle::Arg` just wraps a reference to
39-
/// the value it represents and we need to store a copy of the value, and pass a reference to
40-
/// this copy to C instead.
22+
/// An argument for an FFI call.
4123
#[derive(Debug, Clone)]
42-
pub enum CArg {
24+
pub enum OwnedArg {
4325
/// Primitive type.
44-
Primitive(CPrimitive),
45-
/// Struct with its computed type layout and bytes.
46-
Struct(FfiType, Box<[u8]>),
26+
Primitive(ScalarArg),
27+
/// ADT with its computed type layout and bytes.
28+
Adt(FfiType, Box<[u8]>),
4729
}
4830

49-
impl CArg {
50-
/// Convert a `CArg` to the required FFI argument type.
51-
pub fn arg_downcast<'a>(&'a self) -> FfiArg<'a> {
31+
impl OwnedArg {
32+
/// Gets the libffi type descriptor for this argument.
33+
fn ty(&self) -> FfiType {
5234
match self {
53-
CArg::Primitive(cprim) => cprim.arg_downcast(),
35+
OwnedArg::Primitive(scalar_arg) => scalar_arg.ty(),
36+
OwnedArg::Adt(ty, _) => ty.clone(),
37+
}
38+
}
39+
40+
/// Instantiates a libffi argument pointer pointing to this argument's bytes.
41+
/// NB: Since `libffi::middle::Arg` ignores the lifetime of the reference
42+
/// it's derived from, it is up to the caller to ensure the `OwnedArg` is
43+
/// not dropped before unsafely calling `libffi::middle::Cif::call()`!
44+
fn ptr(&self) -> ArgPtr {
45+
match self {
46+
OwnedArg::Primitive(scalar_arg) => scalar_arg.ptr(),
5447
// FIXME: Using `&items[0]` to reference the whole array is definitely
5548
// unsound under SB, but we're waiting on
5649
// https://github.com/libffi-rs/libffi-rs/commit/112a37b3b6ffb35bd75241fbcc580de40ba74a73
5750
// to land in a release so that we don't need to do this.
58-
CArg::Struct(cstruct, items) => FfiArg::new(cstruct.clone(), ArgPtr::new(&items[0])),
51+
OwnedArg::Adt(_, items) => ArgPtr::new(&items[0]),
5952
}
6053
}
6154
}
6255

63-
impl From<CPrimitive> for CArg {
64-
fn from(prim: CPrimitive) -> Self {
56+
impl From<ScalarArg> for OwnedArg {
57+
fn from(prim: ScalarArg) -> Self {
6558
Self::Primitive(prim)
6659
}
6760
}
6861

6962
#[derive(Debug, Clone)]
70-
/// Enum of supported primitive arguments to external C functions.
71-
pub enum CPrimitive {
63+
/// Enum of supported scalar arguments to external C functions.
64+
pub enum ScalarArg {
7265
/// 8-bit signed integer.
7366
Int8(i8),
7467
/// 16-bit signed integer.
@@ -93,21 +86,38 @@ pub enum CPrimitive {
9386
RawPtr(*mut std::ffi::c_void),
9487
}
9588

96-
impl CPrimitive {
97-
/// Convert a primitive to the required FFI argument type.
98-
fn arg_downcast<'a>(&'a self) -> FfiArg<'a> {
89+
impl ScalarArg {
90+
/// See `OwnedArg::ty()`.
91+
fn ty(&self) -> FfiType {
92+
match self {
93+
ScalarArg::Int8(_) => FfiType::i8(),
94+
ScalarArg::Int16(_) => FfiType::i16(),
95+
ScalarArg::Int32(_) => FfiType::i32(),
96+
ScalarArg::Int64(_) => FfiType::i64(),
97+
ScalarArg::ISize(_) => FfiType::isize(),
98+
ScalarArg::UInt8(_) => FfiType::u8(),
99+
ScalarArg::UInt16(_) => FfiType::u16(),
100+
ScalarArg::UInt32(_) => FfiType::u32(),
101+
ScalarArg::UInt64(_) => FfiType::u64(),
102+
ScalarArg::USize(_) => FfiType::usize(),
103+
ScalarArg::RawPtr(_) => FfiType::pointer(),
104+
}
105+
}
106+
107+
/// See `OwnedArg::ptr()`.
108+
fn ptr(&self) -> ArgPtr {
99109
match self {
100-
CPrimitive::Int8(i) => FfiArg::new(FfiType::i8(), ArgPtr::new(i)),
101-
CPrimitive::Int16(i) => FfiArg::new(FfiType::i16(), ArgPtr::new(i)),
102-
CPrimitive::Int32(i) => FfiArg::new(FfiType::i32(), ArgPtr::new(i)),
103-
CPrimitive::Int64(i) => FfiArg::new(FfiType::i64(), ArgPtr::new(i)),
104-
CPrimitive::ISize(i) => FfiArg::new(FfiType::isize(), ArgPtr::new(i)),
105-
CPrimitive::UInt8(i) => FfiArg::new(FfiType::u8(), ArgPtr::new(i)),
106-
CPrimitive::UInt16(i) => FfiArg::new(FfiType::u16(), ArgPtr::new(i)),
107-
CPrimitive::UInt32(i) => FfiArg::new(FfiType::u32(), ArgPtr::new(i)),
108-
CPrimitive::UInt64(i) => FfiArg::new(FfiType::u64(), ArgPtr::new(i)),
109-
CPrimitive::USize(i) => FfiArg::new(FfiType::usize(), ArgPtr::new(i)),
110-
CPrimitive::RawPtr(i) => FfiArg::new(FfiType::pointer(), ArgPtr::new(i)),
110+
ScalarArg::Int8(i) => ArgPtr::new(i),
111+
ScalarArg::Int16(i) => ArgPtr::new(i),
112+
ScalarArg::Int32(i) => ArgPtr::new(i),
113+
ScalarArg::Int64(i) => ArgPtr::new(i),
114+
ScalarArg::ISize(i) => ArgPtr::new(i),
115+
ScalarArg::UInt8(i) => ArgPtr::new(i),
116+
ScalarArg::UInt16(i) => ArgPtr::new(i),
117+
ScalarArg::UInt32(i) => ArgPtr::new(i),
118+
ScalarArg::UInt64(i) => ArgPtr::new(i),
119+
ScalarArg::USize(i) => ArgPtr::new(i),
120+
ScalarArg::RawPtr(i) => ArgPtr::new(i),
111121
}
112122
}
113123
}

src/shims/native_lib/mod.rs

Lines changed: 77 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ mod ffi;
2121
)]
2222
pub mod trace;
2323

24-
use self::ffi::{CArg, CPrimitive, FfiArg};
24+
use self::ffi::{OwnedArg, ScalarArg};
2525
use crate::*;
2626

2727
/// The final results of an FFI trace, containing every relevant event detected
@@ -72,12 +72,12 @@ impl AccessRange {
7272
impl<'tcx> EvalContextExtPriv<'tcx> for crate::MiriInterpCx<'tcx> {}
7373
trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
7474
/// Call native host function and return the output as an immediate.
75-
fn call_native_with_args<'a>(
75+
fn call_native_with_args(
7676
&mut self,
7777
link_name: Symbol,
7878
dest: &MPlaceTy<'tcx>,
7979
ptr: CodePtr,
80-
libffi_args: Vec<FfiArg<'a>>,
80+
libffi_args: &[OwnedArg],
8181
) -> InterpResult<'tcx, (crate::ImmTy<'tcx>, Option<MemEvents>)> {
8282
let this = self.eval_context_mut();
8383
#[cfg(target_os = "linux")]
@@ -271,95 +271,90 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
271271
}
272272

273273
/// Extract the value from the result of reading an operand from the machine
274-
/// and convert it to a `CArg`.
275-
fn op_to_ffi_arg(&self, v: &OpTy<'tcx>, tracing: bool) -> InterpResult<'tcx, CArg> {
274+
/// and convert it to a `OwnedArg`.
275+
fn op_to_ffi_arg(&self, v: &OpTy<'tcx>, tracing: bool) -> InterpResult<'tcx, OwnedArg> {
276276
let this = self.eval_context_ref();
277277
let scalar = |v| interp_ok(this.read_immediate(v)?.to_scalar());
278278
interp_ok(match v.layout.ty.kind() {
279279
// If the primitive provided can be converted to a type matching the type pattern
280-
// then create a `CArg` of this primitive value with the corresponding `CArg` constructor.
280+
// then create a `OwnedArg` of this primitive value with the corresponding `OwnedArg` constructor.
281281
// the ints
282-
ty::Int(IntTy::I8) => CPrimitive::Int8(scalar(v)?.to_i8()?).into(),
283-
ty::Int(IntTy::I16) => CPrimitive::Int16(scalar(v)?.to_i16()?).into(),
284-
ty::Int(IntTy::I32) => CPrimitive::Int32(scalar(v)?.to_i32()?).into(),
285-
ty::Int(IntTy::I64) => CPrimitive::Int64(scalar(v)?.to_i64()?).into(),
282+
ty::Int(IntTy::I8) => ScalarArg::Int8(scalar(v)?.to_i8()?).into(),
283+
ty::Int(IntTy::I16) => ScalarArg::Int16(scalar(v)?.to_i16()?).into(),
284+
ty::Int(IntTy::I32) => ScalarArg::Int32(scalar(v)?.to_i32()?).into(),
285+
ty::Int(IntTy::I64) => ScalarArg::Int64(scalar(v)?.to_i64()?).into(),
286286
ty::Int(IntTy::Isize) =>
287-
CPrimitive::ISize(scalar(v)?.to_target_isize(this)?.try_into().unwrap()).into(),
287+
ScalarArg::ISize(scalar(v)?.to_target_isize(this)?.try_into().unwrap()).into(),
288288
// the uints
289-
ty::Uint(UintTy::U8) => CPrimitive::UInt8(scalar(v)?.to_u8()?).into(),
290-
ty::Uint(UintTy::U16) => CPrimitive::UInt16(scalar(v)?.to_u16()?).into(),
291-
ty::Uint(UintTy::U32) => CPrimitive::UInt32(scalar(v)?.to_u32()?).into(),
292-
ty::Uint(UintTy::U64) => CPrimitive::UInt64(scalar(v)?.to_u64()?).into(),
289+
ty::Uint(UintTy::U8) => ScalarArg::UInt8(scalar(v)?.to_u8()?).into(),
290+
ty::Uint(UintTy::U16) => ScalarArg::UInt16(scalar(v)?.to_u16()?).into(),
291+
ty::Uint(UintTy::U32) => ScalarArg::UInt32(scalar(v)?.to_u32()?).into(),
292+
ty::Uint(UintTy::U64) => ScalarArg::UInt64(scalar(v)?.to_u64()?).into(),
293293
ty::Uint(UintTy::Usize) =>
294-
CPrimitive::USize(scalar(v)?.to_target_usize(this)?.try_into().unwrap()).into(),
294+
ScalarArg::USize(scalar(v)?.to_target_usize(this)?.try_into().unwrap()).into(),
295295
ty::RawPtr(..) => {
296296
let ptr = scalar(v)?.to_pointer(this)?;
297-
// Pointer without provenance may not access any memory anyway, skip.
298-
if let Some(prov) = ptr.provenance {
299-
// The first time this happens, print a warning.
300-
if !this.machine.native_call_mem_warned.replace(true) {
301-
// Newly set, so first time we get here.
302-
this.emit_diagnostic(NonHaltingDiagnostic::NativeCallSharedMem { tracing });
303-
}
304-
305-
this.expose_provenance(prov)?;
306-
};
297+
this.expose_and_warn(ptr.provenance, tracing)?;
307298

308299
// This relies on the `expose_provenance` in the `visit_reachable_allocs` callback
309300
// below to expose the actual interpreter-level allocation.
310-
CPrimitive::RawPtr(std::ptr::with_exposed_provenance_mut(ptr.addr().bytes_usize()))
301+
ScalarArg::RawPtr(std::ptr::with_exposed_provenance_mut(ptr.addr().bytes_usize()))
311302
.into()
312303
}
313304
// For ADTs, create an FfiType from their fields.
314305
ty::Adt(adt_def, args) => {
315-
let strukt = this.adt_to_ffitype(v.layout.ty, *adt_def, args)?;
306+
let adt = this.adt_to_ffitype(v.layout.ty, *adt_def, args)?;
316307

317308
// Copy the raw bytes backing this arg.
318309
let bytes = match v.as_mplace_or_imm() {
319310
either::Either::Left(mplace) => {
320-
// We do all of this to grab the bytes without actually
321-
// stripping provenance from them, since it'll later be
322-
// exposed recursively.
323-
let ptr = mplace.ptr();
324-
// Make sure the provenance of this allocation is exposed;
325-
// there must be one for this mplace to be valid at all.
326-
// The interpreter-level allocation itself is exposed in
327-
// visit_reachable_allocs.
328-
this.expose_provenance(ptr.provenance.unwrap())?;
329-
// Then get the actual bytes.
311+
// Get the alloc id corresponding to this mplace, alongside
312+
// a pointer that's offset to point to this particular
313+
// mplace (not one at the base addr of the allocation).
314+
let mplace_ptr = mplace.ptr();
315+
let sz = mplace.layout.size.bytes_usize();
330316
let id = this
331317
.alloc_id_from_addr(
332-
ptr.addr().bytes(),
333-
mplace.layout.size.bytes_usize().try_into().unwrap(),
334-
/* only_exposed_allocations */ true,
318+
mplace_ptr.addr().bytes(),
319+
sz.try_into().unwrap(),
320+
/* only_exposed_allocations */ false,
335321
)
336322
.unwrap();
337-
let ptr_raw = this.get_alloc_bytes_unchecked_raw(id)?;
338-
// SAFETY: We know for sure that at ptr_raw the next layout.size bytes
339-
// are part of this allocation and initialised. They might be marked as
340-
// uninit in Miri, but all bytes returned by `MiriAllocBytes` are
323+
// Expose all provenances in the allocation within the byte
324+
// range of the struct, if any.
325+
let alloc = this.get_alloc_raw(id)?;
326+
let alloc_ptr = this.get_alloc_bytes_unchecked_raw(id)?;
327+
let start_addr =
328+
mplace_ptr.addr().bytes_usize().strict_sub(alloc_ptr.addr());
329+
for byte in start_addr..start_addr.strict_add(sz) {
330+
if let Some(prov) = alloc.provenance().get(Size::from_bytes(byte), this)
331+
{
332+
this.expose_provenance(prov)?;
333+
}
334+
}
335+
// SAFETY: We know for sure that at mplace_ptr.addr() the next layout.size
336+
// bytes are part of this allocation and initialised. They might be marked
337+
// as uninit in Miri, but all bytes returned by `MiriAllocBytes` are
341338
// initialised.
342339
unsafe {
343-
std::slice::from_raw_parts(ptr_raw, mplace.layout.size.bytes_usize())
344-
.to_vec()
345-
.into_boxed_slice()
340+
std::slice::from_raw_parts(
341+
alloc_ptr.with_addr(mplace_ptr.addr().bytes_usize()),
342+
mplace.layout.size.bytes_usize(),
343+
)
344+
.to_vec()
345+
.into_boxed_slice()
346346
}
347347
}
348-
either::Either::Right(imm) => {
349-
// For immediates, we know the backing scalar is going to be 128 bits,
350-
// so we can just copy that.
351-
// TODO: is it possible for this to be a scalar pair?
352-
let scalar = imm.to_scalar();
353-
if scalar.size().bytes() > 0 {
354-
let bits = scalar.to_bits(scalar.size())?;
355-
bits.to_ne_bytes().to_vec().into_boxed_slice()
356-
} else {
357-
throw_ub_format!("attempting to pass a ZST over FFI: {}", imm.layout.ty)
358-
}
348+
either::Either::Right(_) => {
349+
// TODO: support this!
350+
throw_unsup_format!(
351+
"Immediate structs can't be passed over FFI: {}",
352+
v.layout.ty
353+
)
359354
}
360355
};
361356

362-
ffi::CArg::Struct(strukt, bytes)
357+
ffi::OwnedArg::Adt(adt, bytes)
363358
}
364359
_ => throw_unsup_format!("unsupported argument type for native call: {}", v.layout.ty),
365360
})
@@ -372,14 +367,15 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
372367
adt_def: ty::AdtDef<'tcx>,
373368
args: &'tcx ty::List<ty::GenericArg<'tcx>>,
374369
) -> InterpResult<'tcx, FfiType> {
375-
// TODO: is this correct? Maybe `repr(transparent)` when the inner field
376-
// is itself `repr(c)` is ok?
370+
// TODO: Certain non-C reprs should be okay also.
377371
if !adt_def.repr().c() {
378-
throw_ub_format!("passing a non-#[repr(C)] struct over FFI: {orig_ty}")
372+
throw_unsup_format!("passing a non-#[repr(C)] struct over FFI: {orig_ty}")
379373
}
380374
// TODO: unions, etc.
381375
if !adt_def.is_struct() {
382-
throw_unsup_format!("unsupported argument type for native call: {orig_ty} is an enum or union");
376+
throw_unsup_format!(
377+
"unsupported argument type for native call: {orig_ty} is an enum or union"
378+
);
383379
}
384380

385381
let this = self.eval_context_ref();
@@ -410,6 +406,20 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
410406
_ => throw_unsup_format!("unsupported argument type for native call: {}", ty),
411407
})
412408
}
409+
410+
fn expose_and_warn(&self, prov: Option<Provenance>, tracing: bool) -> InterpResult<'tcx> {
411+
let this = self.eval_context_ref();
412+
if let Some(prov) = prov {
413+
// The first time this happens, print a warning.
414+
if !this.machine.native_call_mem_warned.replace(true) {
415+
// Newly set, so first time we get here.
416+
this.emit_diagnostic(NonHaltingDiagnostic::NativeCallSharedMem { tracing });
417+
}
418+
419+
this.expose_provenance(prov)?;
420+
};
421+
interp_ok(())
422+
}
413423
}
414424

415425
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -439,12 +449,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
439449
let tracing = trace::Supervisor::is_enabled();
440450

441451
// Get the function arguments, copy them, and prepare the type descriptions.
442-
let mut libffi_args = Vec::<CArg>::with_capacity(args.len());
452+
let mut libffi_args = Vec::<OwnedArg>::with_capacity(args.len());
443453
for arg in args.iter() {
444-
libffi_args.push(this.op_to_carg(arg, tracing)?);
454+
libffi_args.push(this.op_to_ffi_arg(arg, tracing)?);
445455
}
446-
// Convert arguments to a libffi-compatible type.
447-
let libffi_args = libffi_args.iter().map(|arg| arg.arg_downcast()).collect::<Vec<_>>();
448456

449457
// Prepare all exposed memory (both previously exposed, and just newly exposed since a
450458
// pointer was passed as argument). Uninitialised memory is left as-is, but any data
@@ -487,7 +495,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
487495

488496
// Call the function and store output, depending on return type in the function signature.
489497
let (ret, maybe_memevents) =
490-
this.call_native_with_args(link_name, dest, code_ptr, libffi_args)?;
498+
this.call_native_with_args(link_name, dest, code_ptr, &libffi_args)?;
491499

492500
if tracing {
493501
this.tracing_apply_accesses(maybe_memevents.unwrap())?;

0 commit comments

Comments
 (0)