Skip to content

Commit 3829c5c

Browse files
committed
address review, refactor ffi logic
1 parent 57f796e commit 3829c5c

15 files changed

+347
-214
lines changed

src/shims/native_lib/ffi.rs

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,72 +3,65 @@ use libffi::middle::{Arg as ArgPtr, Cif, Type as FfiType};
33

44
/// Perform the actual FFI call.
55
///
6-
/// SAFETY: The `FfiArg`s passed must have been correctly instantiated (i.e. their
7-
/// type layout must match the data they point to), and the safety invariants of
8-
/// the foreign function being called must be upheld (if any).
9-
pub unsafe fn call<'a, R: libffi::high::CType>(fun: CodePtr, args: Vec<FfiArg<'a>>) -> R {
6+
/// SAFETY: The safety invariants of the foreign function being called must be
7+
/// upheld (if any).
8+
pub unsafe fn call<R: libffi::high::CType>(fun: CodePtr, args: &[OwnedArg]) -> R {
109
let mut arg_tys = vec![];
1110
let mut arg_ptrs = vec![];
1211
for arg in args {
13-
arg_tys.push(arg.ty);
14-
arg_ptrs.push(arg.ptr)
12+
arg_tys.push(arg.ty());
13+
arg_ptrs.push(arg.ptr())
1514
}
1615
let cif = Cif::new(arg_tys, R::reify().into_middle());
16+
// SAFETY: Caller upholds that the function is safe to call, and since we
17+
// were passed a slice reference we know the `OwnedArg`s won't have been
18+
// been dropped by this point.
1719
unsafe { cif.call(fun, &arg_ptrs) }
1820
}
1921

20-
/// A wrapper type for `libffi::middle::Type` which also holds a pointer to the data.
21-
pub struct FfiArg<'a> {
22-
/// The type layout information for the pointed-to data.
23-
ty: FfiType,
24-
/// A pointer to the data described in `ty`.
25-
ptr: ArgPtr,
26-
/// Lifetime of the actual pointed-to data.
27-
_p: std::marker::PhantomData<&'a [u8]>,
28-
}
29-
30-
impl<'a> FfiArg<'a> {
31-
fn new(ty: FfiType, ptr: ArgPtr) -> Self {
32-
Self { ty, ptr, _p: std::marker::PhantomData }
33-
}
34-
}
35-
36-
/// An owning form of `FfiArg`.
37-
/// We introduce this enum instead of just calling `Arg::new` and storing a list of
38-
/// `libffi::middle::Arg` directly, because the `libffi::middle::Arg` just wraps a reference to
39-
/// the value it represents and we need to store a copy of the value, and pass a reference to
40-
/// this copy to C instead.
22+
/// An argument for an FFI call.
4123
#[derive(Debug, Clone)]
42-
pub enum CArg {
24+
pub enum OwnedArg {
4325
/// Primitive type.
44-
Primitive(CPrimitive),
45-
/// Struct with its computed type layout and bytes.
46-
Struct(FfiType, Box<[u8]>),
26+
Primitive(ScalarArg),
27+
/// ADT with its computed type layout and bytes.
28+
Adt(FfiType, Box<[u8]>),
4729
}
4830

49-
impl CArg {
50-
/// Convert a `CArg` to the required FFI argument type.
51-
pub fn arg_downcast<'a>(&'a self) -> FfiArg<'a> {
31+
impl OwnedArg {
32+
/// Gets the libffi type descriptor for this argument.
33+
fn ty(&self) -> FfiType {
5234
match self {
53-
CArg::Primitive(cprim) => cprim.arg_downcast(),
35+
OwnedArg::Primitive(scalar_arg) => scalar_arg.ty(),
36+
OwnedArg::Adt(ty, _) => ty.clone(),
37+
}
38+
}
39+
40+
/// Instantiates a libffi argument pointer pointing to this argument's bytes.
41+
/// NB: Since `libffi::middle::Arg` ignores the lifetime of the reference
42+
/// it's derived from, it is up to the caller to ensure the `OwnedArg` is
43+
/// not dropped before unsafely calling `libffi::middle::Cif::call()`!
44+
fn ptr(&self) -> ArgPtr {
45+
match self {
46+
OwnedArg::Primitive(scalar_arg) => scalar_arg.ptr(),
5447
// FIXME: Using `&items[0]` to reference the whole array is definitely
5548
// unsound under SB, but we're waiting on
5649
// https://github.com/libffi-rs/libffi-rs/commit/112a37b3b6ffb35bd75241fbcc580de40ba74a73
5750
// to land in a release so that we don't need to do this.
58-
CArg::Struct(cstruct, items) => FfiArg::new(cstruct.clone(), ArgPtr::new(&items[0])),
51+
OwnedArg::Adt(_, items) => ArgPtr::new(&items[0]),
5952
}
6053
}
6154
}
6255

63-
impl From<CPrimitive> for CArg {
64-
fn from(prim: CPrimitive) -> Self {
56+
impl From<ScalarArg> for OwnedArg {
57+
fn from(prim: ScalarArg) -> Self {
6558
Self::Primitive(prim)
6659
}
6760
}
6861

6962
#[derive(Debug, Clone)]
70-
/// Enum of supported primitive arguments to external C functions.
71-
pub enum CPrimitive {
63+
/// Enum of supported scalar arguments to external C functions.
64+
pub enum ScalarArg {
7265
/// 8-bit signed integer.
7366
Int8(i8),
7467
/// 16-bit signed integer.
@@ -93,21 +86,38 @@ pub enum CPrimitive {
9386
RawPtr(*mut std::ffi::c_void),
9487
}
9588

96-
impl CPrimitive {
97-
/// Convert a primitive to the required FFI argument type.
98-
fn arg_downcast<'a>(&'a self) -> FfiArg<'a> {
89+
impl ScalarArg {
90+
/// See `OwnedArg::ty()`.
91+
fn ty(&self) -> FfiType {
92+
match self {
93+
ScalarArg::Int8(_) => FfiType::i8(),
94+
ScalarArg::Int16(_) => FfiType::i16(),
95+
ScalarArg::Int32(_) => FfiType::i32(),
96+
ScalarArg::Int64(_) => FfiType::i64(),
97+
ScalarArg::ISize(_) => FfiType::isize(),
98+
ScalarArg::UInt8(_) => FfiType::u8(),
99+
ScalarArg::UInt16(_) => FfiType::u16(),
100+
ScalarArg::UInt32(_) => FfiType::u32(),
101+
ScalarArg::UInt64(_) => FfiType::u64(),
102+
ScalarArg::USize(_) => FfiType::usize(),
103+
ScalarArg::RawPtr(_) => FfiType::pointer(),
104+
}
105+
}
106+
107+
/// See `OwnedArg::ptr()`.
108+
fn ptr(&self) -> ArgPtr {
99109
match self {
100-
CPrimitive::Int8(i) => FfiArg::new(FfiType::i8(), ArgPtr::new(i)),
101-
CPrimitive::Int16(i) => FfiArg::new(FfiType::i16(), ArgPtr::new(i)),
102-
CPrimitive::Int32(i) => FfiArg::new(FfiType::i32(), ArgPtr::new(i)),
103-
CPrimitive::Int64(i) => FfiArg::new(FfiType::i64(), ArgPtr::new(i)),
104-
CPrimitive::ISize(i) => FfiArg::new(FfiType::isize(), ArgPtr::new(i)),
105-
CPrimitive::UInt8(i) => FfiArg::new(FfiType::u8(), ArgPtr::new(i)),
106-
CPrimitive::UInt16(i) => FfiArg::new(FfiType::u16(), ArgPtr::new(i)),
107-
CPrimitive::UInt32(i) => FfiArg::new(FfiType::u32(), ArgPtr::new(i)),
108-
CPrimitive::UInt64(i) => FfiArg::new(FfiType::u64(), ArgPtr::new(i)),
109-
CPrimitive::USize(i) => FfiArg::new(FfiType::usize(), ArgPtr::new(i)),
110-
CPrimitive::RawPtr(i) => FfiArg::new(FfiType::pointer(), ArgPtr::new(i)),
110+
ScalarArg::Int8(i) => ArgPtr::new(i),
111+
ScalarArg::Int16(i) => ArgPtr::new(i),
112+
ScalarArg::Int32(i) => ArgPtr::new(i),
113+
ScalarArg::Int64(i) => ArgPtr::new(i),
114+
ScalarArg::ISize(i) => ArgPtr::new(i),
115+
ScalarArg::UInt8(i) => ArgPtr::new(i),
116+
ScalarArg::UInt16(i) => ArgPtr::new(i),
117+
ScalarArg::UInt32(i) => ArgPtr::new(i),
118+
ScalarArg::UInt64(i) => ArgPtr::new(i),
119+
ScalarArg::USize(i) => ArgPtr::new(i),
120+
ScalarArg::RawPtr(i) => ArgPtr::new(i),
111121
}
112122
}
113123
}

src/shims/native_lib/mod.rs

Lines changed: 77 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ mod ffi;
2121
)]
2222
pub mod trace;
2323

24-
use self::ffi::{CArg, CPrimitive, FfiArg};
24+
use self::ffi::{OwnedArg, ScalarArg};
2525
use crate::*;
2626

2727
/// The final results of an FFI trace, containing every relevant event detected
@@ -72,12 +72,12 @@ impl AccessRange {
7272
impl<'tcx> EvalContextExtPriv<'tcx> for crate::MiriInterpCx<'tcx> {}
7373
trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
7474
/// Call native host function and return the output as an immediate.
75-
fn call_native_with_args<'a>(
75+
fn call_native_with_args(
7676
&mut self,
7777
link_name: Symbol,
7878
dest: &MPlaceTy<'tcx>,
7979
ptr: CodePtr,
80-
libffi_args: Vec<FfiArg<'a>>,
80+
libffi_args: &[OwnedArg],
8181
) -> InterpResult<'tcx, (crate::ImmTy<'tcx>, Option<MemEvents>)> {
8282
let this = self.eval_context_mut();
8383
#[cfg(target_os = "linux")]
@@ -275,95 +275,90 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
275275
}
276276

277277
/// Extract the value from the result of reading an operand from the machine
278-
/// and convert it to a `CArg`.
279-
fn op_to_ffi_arg(&self, v: &OpTy<'tcx>, tracing: bool) -> InterpResult<'tcx, CArg> {
278+
/// and convert it to an `OwnedArg`.
279+
fn op_to_ffi_arg(&self, v: &OpTy<'tcx>, tracing: bool) -> InterpResult<'tcx, OwnedArg> {
280280
let this = self.eval_context_ref();
281281
let scalar = |v| interp_ok(this.read_immediate(v)?.to_scalar());
282282
interp_ok(match v.layout.ty.kind() {
283283
// If the primitive provided can be converted to a type matching the type pattern
284-
// then create a `CArg` of this primitive value with the corresponding `CArg` constructor.
284+
// then create an `OwnedArg` of this primitive value with the corresponding `OwnedArg` constructor.
285285
// the ints
286-
ty::Int(IntTy::I8) => CPrimitive::Int8(scalar(v)?.to_i8()?).into(),
287-
ty::Int(IntTy::I16) => CPrimitive::Int16(scalar(v)?.to_i16()?).into(),
288-
ty::Int(IntTy::I32) => CPrimitive::Int32(scalar(v)?.to_i32()?).into(),
289-
ty::Int(IntTy::I64) => CPrimitive::Int64(scalar(v)?.to_i64()?).into(),
286+
ty::Int(IntTy::I8) => ScalarArg::Int8(scalar(v)?.to_i8()?).into(),
287+
ty::Int(IntTy::I16) => ScalarArg::Int16(scalar(v)?.to_i16()?).into(),
288+
ty::Int(IntTy::I32) => ScalarArg::Int32(scalar(v)?.to_i32()?).into(),
289+
ty::Int(IntTy::I64) => ScalarArg::Int64(scalar(v)?.to_i64()?).into(),
290290
ty::Int(IntTy::Isize) =>
291-
CPrimitive::ISize(scalar(v)?.to_target_isize(this)?.try_into().unwrap()).into(),
291+
ScalarArg::ISize(scalar(v)?.to_target_isize(this)?.try_into().unwrap()).into(),
292292
// the uints
293-
ty::Uint(UintTy::U8) => CPrimitive::UInt8(scalar(v)?.to_u8()?).into(),
294-
ty::Uint(UintTy::U16) => CPrimitive::UInt16(scalar(v)?.to_u16()?).into(),
295-
ty::Uint(UintTy::U32) => CPrimitive::UInt32(scalar(v)?.to_u32()?).into(),
296-
ty::Uint(UintTy::U64) => CPrimitive::UInt64(scalar(v)?.to_u64()?).into(),
293+
ty::Uint(UintTy::U8) => ScalarArg::UInt8(scalar(v)?.to_u8()?).into(),
294+
ty::Uint(UintTy::U16) => ScalarArg::UInt16(scalar(v)?.to_u16()?).into(),
295+
ty::Uint(UintTy::U32) => ScalarArg::UInt32(scalar(v)?.to_u32()?).into(),
296+
ty::Uint(UintTy::U64) => ScalarArg::UInt64(scalar(v)?.to_u64()?).into(),
297297
ty::Uint(UintTy::Usize) =>
298-
CPrimitive::USize(scalar(v)?.to_target_usize(this)?.try_into().unwrap()).into(),
298+
ScalarArg::USize(scalar(v)?.to_target_usize(this)?.try_into().unwrap()).into(),
299299
ty::RawPtr(..) => {
300300
let ptr = scalar(v)?.to_pointer(this)?;
301-
// Pointer without provenance may not access any memory anyway, skip.
302-
if let Some(prov) = ptr.provenance {
303-
// The first time this happens, print a warning.
304-
if !this.machine.native_call_mem_warned.replace(true) {
305-
// Newly set, so first time we get here.
306-
this.emit_diagnostic(NonHaltingDiagnostic::NativeCallSharedMem { tracing });
307-
}
308-
309-
this.expose_provenance(prov)?;
310-
};
301+
this.expose_and_warn(ptr.provenance, tracing)?;
311302

312303
// This relies on the `expose_provenance` in the `visit_reachable_allocs` callback
313304
// below to expose the actual interpreter-level allocation.
314-
CPrimitive::RawPtr(std::ptr::with_exposed_provenance_mut(ptr.addr().bytes_usize()))
305+
ScalarArg::RawPtr(std::ptr::with_exposed_provenance_mut(ptr.addr().bytes_usize()))
315306
.into()
316307
}
317308
// For ADTs, create an FfiType from their fields.
318309
ty::Adt(adt_def, args) => {
319-
let strukt = this.adt_to_ffitype(v.layout.ty, *adt_def, args)?;
310+
let adt = this.adt_to_ffitype(v.layout.ty, *adt_def, args)?;
320311

321312
// Copy the raw bytes backing this arg.
322313
let bytes = match v.as_mplace_or_imm() {
323314
either::Either::Left(mplace) => {
324-
// We do all of this to grab the bytes without actually
325-
// stripping provenance from them, since it'll later be
326-
// exposed recursively.
327-
let ptr = mplace.ptr();
328-
// Make sure the provenance of this allocation is exposed;
329-
// there must be one for this mplace to be valid at all.
330-
// The interpreter-level allocation itself is exposed in
331-
// visit_reachable_allocs.
332-
this.expose_provenance(ptr.provenance.unwrap())?;
333-
// Then get the actual bytes.
315+
// Get the alloc id corresponding to this mplace, alongside
316+
// a pointer that's offset to point to this particular
317+
// mplace (not one at the base addr of the allocation).
318+
let mplace_ptr = mplace.ptr();
319+
let sz = mplace.layout.size.bytes_usize();
334320
let id = this
335321
.alloc_id_from_addr(
336-
ptr.addr().bytes(),
337-
mplace.layout.size.bytes_usize().try_into().unwrap(),
338-
/* only_exposed_allocations */ true,
322+
mplace_ptr.addr().bytes(),
323+
sz.try_into().unwrap(),
324+
/* only_exposed_allocations */ false,
339325
)
340326
.unwrap();
341-
let ptr_raw = this.get_alloc_bytes_unchecked_raw(id)?;
342-
// SAFETY: We know for sure that at ptr_raw the next layout.size bytes
343-
// are part of this allocation and initialised. They might be marked as
344-
// uninit in Miri, but all bytes returned by `MiriAllocBytes` are
327+
// Expose all provenances in the allocation within the byte
328+
// range of the struct, if any.
329+
let alloc = this.get_alloc_raw(id)?;
330+
let alloc_ptr = this.get_alloc_bytes_unchecked_raw(id)?;
331+
let start_addr =
332+
mplace_ptr.addr().bytes_usize().strict_sub(alloc_ptr.addr());
333+
for byte in start_addr..start_addr.strict_add(sz) {
334+
if let Some(prov) = alloc.provenance().get(Size::from_bytes(byte), this)
335+
{
336+
this.expose_provenance(prov)?;
337+
}
338+
}
339+
// SAFETY: We know for sure that at mplace_ptr.addr() the next layout.size
340+
// bytes are part of this allocation and initialised. They might be marked
341+
// as uninit in Miri, but all bytes returned by `MiriAllocBytes` are
345342
// initialised.
346343
unsafe {
347-
std::slice::from_raw_parts(ptr_raw, mplace.layout.size.bytes_usize())
348-
.to_vec()
349-
.into_boxed_slice()
344+
std::slice::from_raw_parts(
345+
alloc_ptr.with_addr(mplace_ptr.addr().bytes_usize()),
346+
mplace.layout.size.bytes_usize(),
347+
)
348+
.to_vec()
349+
.into_boxed_slice()
350350
}
351351
}
352-
either::Either::Right(imm) => {
353-
// For immediates, we know the backing scalar is going to be 128 bits,
354-
// so we can just copy that.
355-
// TODO: is it possible for this to be a scalar pair?
356-
let scalar = imm.to_scalar();
357-
if scalar.size().bytes() > 0 {
358-
let bits = scalar.to_bits(scalar.size())?;
359-
bits.to_ne_bytes().to_vec().into_boxed_slice()
360-
} else {
361-
throw_ub_format!("attempting to pass a ZST over FFI: {}", imm.layout.ty)
362-
}
352+
either::Either::Right(_) => {
353+
// TODO: support this!
354+
throw_unsup_format!(
355+
"Immediate structs can't be passed over FFI: {}",
356+
v.layout.ty
357+
)
363358
}
364359
};
365360

366-
ffi::CArg::Struct(strukt, bytes)
361+
ffi::OwnedArg::Adt(adt, bytes)
367362
}
368363
_ => throw_unsup_format!("unsupported argument type for native call: {}", v.layout.ty),
369364
})
@@ -376,14 +371,15 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
376371
adt_def: ty::AdtDef<'tcx>,
377372
args: &'tcx ty::List<ty::GenericArg<'tcx>>,
378373
) -> InterpResult<'tcx, FfiType> {
379-
// TODO: is this correct? Maybe `repr(transparent)` when the inner field
380-
// is itself `repr(c)` is ok?
374+
// TODO: Certain non-C reprs should be okay also.
381375
if !adt_def.repr().c() {
382-
throw_ub_format!("passing a non-#[repr(C)] struct over FFI: {orig_ty}")
376+
throw_unsup_format!("passing a non-#[repr(C)] struct over FFI: {orig_ty}")
383377
}
384378
// TODO: unions, etc.
385379
if !adt_def.is_struct() {
386-
throw_unsup_format!("unsupported argument type for native call: {orig_ty} is an enum or union");
380+
throw_unsup_format!(
381+
"unsupported argument type for native call: {orig_ty} is an enum or union"
382+
);
387383
}
388384

389385
let this = self.eval_context_ref();
@@ -414,6 +410,20 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
414410
_ => throw_unsup_format!("unsupported argument type for native call: {}", ty),
415411
})
416412
}
413+
414+
fn expose_and_warn(&self, prov: Option<Provenance>, tracing: bool) -> InterpResult<'tcx> {
415+
let this = self.eval_context_ref();
416+
if let Some(prov) = prov {
417+
// The first time this happens, print a warning.
418+
if !this.machine.native_call_mem_warned.replace(true) {
419+
// Newly set, so first time we get here.
420+
this.emit_diagnostic(NonHaltingDiagnostic::NativeCallSharedMem { tracing });
421+
}
422+
423+
this.expose_provenance(prov)?;
424+
};
425+
interp_ok(())
426+
}
417427
}
418428

419429
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -443,12 +453,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
443453
let tracing = trace::Supervisor::is_enabled();
444454

445455
// Get the function arguments, copy them, and prepare the type descriptions.
446-
let mut libffi_args = Vec::<CArg>::with_capacity(args.len());
456+
let mut libffi_args = Vec::<OwnedArg>::with_capacity(args.len());
447457
for arg in args.iter() {
448-
libffi_args.push(this.op_to_carg(arg, tracing)?);
458+
libffi_args.push(this.op_to_ffi_arg(arg, tracing)?);
449459
}
450-
// Convert arguments to a libffi-compatible type.
451-
let libffi_args = libffi_args.iter().map(|arg| arg.arg_downcast()).collect::<Vec<_>>();
452460

453461
// Prepare all exposed memory (both previously exposed, and just newly exposed since a
454462
// pointer was passed as argument). Uninitialised memory is left as-is, but any data
@@ -491,7 +499,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
491499

492500
// Call the function and store output, depending on return type in the function signature.
493501
let (ret, maybe_memevents) =
494-
this.call_native_with_args(link_name, dest, code_ptr, libffi_args)?;
502+
this.call_native_with_args(link_name, dest, code_ptr, &libffi_args)?;
495503

496504
if tracing {
497505
this.tracing_apply_accesses(maybe_memevents.unwrap())?;

0 commit comments

Comments
 (0)