Skip to content

Commit 590abce

Browse files
committed
fix: FFI returns a string/binary instead of requiring the caller to guess the necessary memory
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent 4321359 commit 590abce

File tree

5 files changed

+154
-94
lines changed

5 files changed

+154
-94
lines changed

vortex-ffi/src/array.rs

Lines changed: 41 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
//! FFI interface for working with Vortex Arrays.
5-
use std::ffi::{c_int, c_void};
6-
use std::slice;
5+
use std::ptr;
6+
use std::sync::Arc;
77

88
use vortex::dtype::half::f16;
9-
use vortex::error::{VortexExpect, VortexUnwrap, vortex_err};
9+
use vortex::error::{VortexExpect, vortex_err};
1010
use vortex::{Array, ToCanonical};
1111

1212
use crate::arc_dyn_wrapper;
13+
use crate::binary::vx_binary;
1314
use crate::dtype::vx_dtype;
1415
use crate::error::{try_or_default, vx_error};
16+
use crate::string::vx_string;
1517

1618
arc_dyn_wrapper!(
1719
/// Base type for all Vortex arrays.
@@ -133,48 +135,42 @@ ffiarray_get_ptype!(f16);
133135
ffiarray_get_ptype!(f32);
134136
ffiarray_get_ptype!(f64);
135137

136-
/// Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
137-
/// the length in `len`.
138+
/// Return the utf-8 string at `index` in the array. The pointer will be null if the value at `index` is null.
139+
/// The caller must free the returned pointer.
138140
#[unsafe(no_mangle)]
139141
pub unsafe extern "C-unwind" fn vx_array_get_utf8(
140142
array: *const vx_array,
141143
index: u32,
142-
dst: *mut c_void,
143-
len: *mut c_int,
144-
) {
144+
) -> *const vx_string {
145145
let array = vx_array::as_ref(array);
146146
let value = array.scalar_at(index as usize);
147147
let utf8_scalar = value.as_utf8();
148148
if let Some(buffer) = utf8_scalar.value() {
149-
let bytes = buffer.as_bytes();
150-
let dst = unsafe { slice::from_raw_parts_mut(dst as *mut u8, bytes.len()) };
151-
dst.copy_from_slice(bytes);
152-
unsafe { *len = bytes.len().try_into().vortex_unwrap() };
149+
vx_string::new(Arc::from(buffer.as_str()))
150+
} else {
151+
ptr::null()
153152
}
154153
}
155154

156-
/// Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
157-
/// the length in `len`.
155+
/// Return the binary at `index` in the array. The pointer will be null if the value at `index` is null.
156+
/// The caller must free the returned pointer.
158157
#[unsafe(no_mangle)]
159158
pub unsafe extern "C-unwind" fn vx_array_get_binary(
160159
array: *const vx_array,
161160
index: u32,
162-
dst: *mut c_void,
163-
len: *mut c_int,
164-
) {
161+
) -> *const vx_binary {
165162
let array = vx_array::as_ref(array);
166163
let value = array.scalar_at(index as usize);
167-
let utf8_scalar = value.as_binary();
168-
if let Some(bytes) = utf8_scalar.value() {
169-
let dst = unsafe { slice::from_raw_parts_mut(dst as *mut u8, bytes.len()) };
170-
dst.copy_from_slice(&bytes);
171-
unsafe { *len = bytes.len().try_into().vortex_unwrap() };
164+
let binary_scalar = value.as_binary();
165+
if let Some(bytes) = binary_scalar.value() {
166+
vx_binary::new(Arc::from(bytes.as_bytes()))
167+
} else {
168+
ptr::null()
172169
}
173170
}
174171

175172
#[cfg(test)]
176173
mod tests {
177-
use std::ffi::{c_int, c_void};
178174
use std::ptr;
179175

180176
use vortex::IntoArray;
@@ -185,8 +181,10 @@ mod tests {
185181
use vortex::validity::Validity;
186182

187183
use crate::array::*;
184+
use crate::binary::vx_binary_free;
188185
use crate::dtype::{vx_dtype_get_variant, vx_dtype_variant};
189186
use crate::error::vx_error_free;
187+
use crate::string::vx_string_free;
190188

191189
#[test]
192190
fn test_simple() {
@@ -349,35 +347,17 @@ mod tests {
349347
let utf8_array = VarBinViewArray::from_iter_str(["hello", "world", "test"]);
350348
let ffi_array = vx_array::new(utf8_array.into_array());
351349

352-
let mut buffer = vec![0u8; 10];
353-
let mut len: c_int = 0;
350+
let vx_str1 = vx_array_get_utf8(ffi_array, 0);
351+
assert_eq!(vx_string::as_str(vx_str1), "hello");
352+
vx_string_free(vx_str1);
354353

355-
vx_array_get_utf8(
356-
ffi_array,
357-
0,
358-
buffer.as_mut_ptr() as *mut c_void,
359-
&raw mut len,
360-
);
361-
assert_eq!(len, 5);
362-
assert_eq!(&buffer[..5], b"hello");
363-
364-
vx_array_get_utf8(
365-
ffi_array,
366-
1,
367-
buffer.as_mut_ptr() as *mut c_void,
368-
&raw mut len,
369-
);
370-
assert_eq!(len, 5);
371-
assert_eq!(&buffer[..5], b"world");
372-
373-
vx_array_get_utf8(
374-
ffi_array,
375-
2,
376-
buffer.as_mut_ptr() as *mut c_void,
377-
&raw mut len,
378-
);
379-
assert_eq!(len, 4);
380-
assert_eq!(&buffer[..4], b"test");
354+
let vx_str2 = vx_array_get_utf8(ffi_array, 1);
355+
assert_eq!(vx_string::as_str(vx_str2), "world");
356+
vx_string_free(vx_str2);
357+
358+
let vx_str3 = vx_array_get_utf8(ffi_array, 2);
359+
assert_eq!(vx_string::as_str(vx_str3), "test");
360+
vx_string_free(vx_str3);
381361

382362
vx_array_free(ffi_array);
383363
}
@@ -393,35 +373,17 @@ mod tests {
393373
]);
394374
let ffi_array = vx_array::new(binary_array.into_array());
395375

396-
let mut buffer = vec![0u8; 10];
397-
let mut len: c_int = 0;
376+
let vx_bin1 = vx_array_get_binary(ffi_array, 0);
377+
assert_eq!(vx_binary::as_slice(vx_bin1), &[0x01, 0x02, 0x03]);
378+
vx_binary_free(vx_bin1);
398379

399-
vx_array_get_binary(
400-
ffi_array,
401-
0,
402-
buffer.as_mut_ptr() as *mut c_void,
403-
&raw mut len,
404-
);
405-
assert_eq!(len, 3);
406-
assert_eq!(&buffer[..3], &[0x01, 0x02, 0x03]);
407-
408-
vx_array_get_binary(
409-
ffi_array,
410-
1,
411-
buffer.as_mut_ptr() as *mut c_void,
412-
&raw mut len,
413-
);
414-
assert_eq!(len, 2);
415-
assert_eq!(&buffer[..2], &[0xFF, 0xEE]);
416-
417-
vx_array_get_binary(
418-
ffi_array,
419-
2,
420-
buffer.as_mut_ptr() as *mut c_void,
421-
&raw mut len,
422-
);
423-
assert_eq!(len, 4);
424-
assert_eq!(&buffer[..4], &[0xAA, 0xBB, 0xCC, 0xDD]);
380+
let vx_bin2 = vx_array_get_binary(ffi_array, 1);
381+
assert_eq!(vx_binary::as_slice(vx_bin2), &[0xFF, 0xEE]);
382+
vx_binary_free(vx_bin2);
383+
384+
let vx_bin3 = vx_array_get_binary(ffi_array, 2);
385+
assert_eq!(vx_binary::as_slice(vx_bin3), &[0xAA, 0xBB, 0xCC, 0xDD]);
386+
vx_binary_free(vx_bin3);
425387

426388
vx_array_free(ffi_array);
427389
}

vortex-ffi/src/binary.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ffi::c_char;
5+
use std::slice;
6+
7+
use crate::arc_dyn_wrapper;
8+
9+
arc_dyn_wrapper!(
10+
/// Strings for use within Vortex.
11+
[u8],
12+
vx_binary
13+
);
14+
15+
impl vx_binary {
16+
#[allow(dead_code)]
17+
pub(crate) fn as_slice(ptr: *const vx_binary) -> &'static [u8] {
18+
unsafe { slice::from_raw_parts(vx_binary_ptr(ptr).cast(), vx_binary_len(ptr)) }
19+
}
20+
}
21+
22+
/// Create a new Vortex UTF-8 string by copying from a pointer and length.
23+
#[unsafe(no_mangle)]
24+
pub unsafe extern "C-unwind" fn vx_binary_new(ptr: *const c_char, len: usize) -> *const vx_binary {
25+
let slice = unsafe { slice::from_raw_parts(ptr.cast(), len) };
26+
vx_binary::new(slice.into())
27+
}
28+
29+
/// Return the length of the string in bytes.
30+
#[unsafe(no_mangle)]
31+
pub unsafe extern "C-unwind" fn vx_binary_len(ptr: *const vx_binary) -> usize {
32+
vx_binary::as_ref(ptr).len()
33+
}
34+
35+
/// Return the pointer to the string data.
36+
#[unsafe(no_mangle)]
37+
pub unsafe extern "C-unwind" fn vx_binary_ptr(ptr: *const vx_binary) -> *const c_char {
38+
vx_binary::as_ref(ptr).as_ptr().cast()
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[test]
46+
fn test_string_new() {
47+
unsafe {
48+
let test_str = "hello world";
49+
let ptr = test_str.as_ptr() as *const c_char;
50+
let len = test_str.len();
51+
52+
let vx_str = vx_binary_new(ptr, len);
53+
assert_eq!(vx_binary_len(vx_str), 11);
54+
assert_eq!(vx_binary::as_slice(vx_str), "hello world".as_bytes());
55+
56+
vx_binary_free(vx_str);
57+
}
58+
}
59+
60+
#[test]
61+
fn test_string_ptr() {
62+
unsafe {
63+
let test_str = "testing".as_bytes();
64+
let vx_str = vx_binary::new(test_str.into());
65+
66+
let ptr = vx_binary_ptr(vx_str);
67+
let len = vx_binary_len(vx_str);
68+
69+
let slice = slice::from_raw_parts(ptr as *const u8, len);
70+
assert_eq!(slice, "testing".as_bytes());
71+
72+
vx_binary_free(vx_str);
73+
}
74+
}
75+
76+
#[test]
77+
fn test_empty_string() {
78+
unsafe {
79+
let empty = "";
80+
let ptr = empty.as_ptr() as *const c_char;
81+
let vx_str = vx_binary_new(ptr, 0);
82+
83+
assert_eq!(vx_binary_len(vx_str), 0);
84+
assert_eq!(vx_binary::as_slice(vx_str), "".as_bytes());
85+
86+
vx_binary_free(vx_str);
87+
}
88+
}
89+
90+
#[test]
91+
fn test_unicode_string() {
92+
unsafe {
93+
let unicode_str = "Hello 世界 🌍";
94+
let ptr = unicode_str.as_ptr() as *const c_char;
95+
let len = unicode_str.len();
96+
97+
let vx_str = vx_binary_new(ptr, len);
98+
assert_eq!(vx_binary_len(vx_str), unicode_str.len());
99+
assert_eq!(vx_binary::as_slice(vx_str), unicode_str.as_bytes());
100+
101+
vx_binary_free(vx_str);
102+
}
103+
}
104+
}

vortex-ffi/src/dtype.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::ffi::{c_int, c_void};
4+
use std::ptr;
55
use std::sync::Arc;
66

77
use vortex::dtype::datetime::{DATE_ID, TIME_ID, TIMESTAMP_ID, TemporalMetadata};
88
use vortex::dtype::{DType, DecimalDType};
9-
use vortex::error::{VortexExpect, VortexUnwrap, vortex_panic};
9+
use vortex::error::{VortexExpect, vortex_panic};
1010

1111
use crate::arc_wrapper;
1212
use crate::ptype::vx_ptype;
13+
use crate::string::vx_string;
1314
use crate::struct_fields::vx_struct_fields;
1415

1516
arc_wrapper!(
@@ -286,13 +287,9 @@ pub unsafe extern "C-unwind" fn vx_dtype_time_unit(dtype: *const DType) -> u8 {
286287
metadata.as_ref()[0]
287288
}
288289

289-
/// Returns the time zone, assuming the type is time.
290+
/// Returns the time zone, assuming the type is time. Caller is responsible for freeing the returned pointer.
290291
#[unsafe(no_mangle)]
291-
pub unsafe extern "C-unwind" fn vx_dtype_time_zone(
292-
dtype: *const DType,
293-
dst: *mut c_void,
294-
len: *mut c_int,
295-
) {
292+
pub unsafe extern "C-unwind" fn vx_dtype_time_zone(dtype: *const DType) -> *const vx_string {
296293
let dtype = unsafe { dtype.as_ref() }.vortex_expect("dtype null");
297294

298295
let DType::Extension(ext_dtype) = dtype else {
@@ -302,13 +299,9 @@ pub unsafe extern "C-unwind" fn vx_dtype_time_zone(
302299
match TemporalMetadata::try_from(ext_dtype).vortex_expect("timestamp") {
303300
TemporalMetadata::Timestamp(_, zone) => {
304301
if let Some(zone) = zone {
305-
let bytes = zone.as_bytes();
306-
let dst = unsafe { std::slice::from_raw_parts_mut(dst as *mut u8, bytes.len()) };
307-
dst.copy_from_slice(bytes);
308-
unsafe { *len = bytes.len().try_into().vortex_unwrap() };
302+
vx_string::new(zone.into())
309303
} else {
310-
// No time zone, using local timestamps.
311-
unsafe { *len = 0 };
304+
ptr::null()
312305
}
313306
}
314307
_ => vortex_panic!("DType_time_zone: not a timestamp metadata: {ext_dtype:?}"),

vortex-ffi/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
mod array;
1010
mod array_iterator;
11+
mod binary;
1112
mod dtype;
1213
mod error;
1314
mod file;

vortex-ffi/src/string.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ arc_dyn_wrapper!(
1616

1717
impl vx_string {
1818
#[allow(dead_code)]
19-
pub(crate) fn as_str<'a>(ptr: *const vx_string) -> &'a str {
19+
pub(crate) fn as_str(ptr: *const vx_string) -> &'static str {
2020
unsafe {
2121
str::from_utf8_unchecked(slice::from_raw_parts(
2222
vx_string_ptr(ptr).cast(),

0 commit comments

Comments
 (0)