Skip to content

Commit 20d88e3

Browse files
authored
fix: FFI returns a string/binary instead of requiring the caller to guess the necessary memory (#5433)
fix: #5373 I think we could have these do RAII but we need something better than just scalar_at Signed-off-by: Robert Kruszewski <[email protected]> --------- Signed-off-by: Robert Kruszewski <[email protected]>
1 parent d72618d commit 20d88e3

File tree

6 files changed

+197
-102
lines changed

6 files changed

+197
-102
lines changed

vortex-ffi/cinclude/vortex.h

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,11 @@ typedef struct vx_array_iterator vx_array_iterator;
273273
*/
274274
typedef struct vx_array_sink vx_array_sink;
275275

276+
/**
277+
* Binary (arbitrary byte string) values for use within Vortex.
278+
*/
279+
typedef struct vx_binary vx_binary;
280+
276281
/**
277282
* A Vortex data type.
278283
*
@@ -461,16 +466,18 @@ double vx_array_get_f64(const vx_array *array, uint32_t index);
461466
double vx_array_get_storage_f64(const vx_array *array, uint32_t index);
462467

463468
/**
464-
* Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
465-
* the length in `len`.
469+
* Return the utf-8 string at `index` in the array. The pointer will be null if the value at `index` is null.
470+
* The caller must free the returned pointer with `vx_string_free`.
466471
*/
467-
void vx_array_get_utf8(const vx_array *array, uint32_t index, void *dst, int *len);
472+
const vx_string *vx_array_get_utf8(const vx_array *array,
473+
uint32_t index);
468474

469475
/**
470-
* Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
471-
* the length in `len`.
476+
* Return the binary at `index` in the array. The pointer will be null if the value at `index` is null.
477+
* The caller must free the returned pointer with `vx_binary_free`.
472478
*/
473-
void vx_array_get_binary(const vx_array *array, uint32_t index, void *dst, int *len);
479+
const vx_binary *vx_array_get_binary(const vx_array *array,
480+
uint32_t index);
474481

475482
/**
476483
* Free an owned [`vx_array_iterator`] object.
@@ -488,6 +495,34 @@ void vx_array_iterator_free(vx_array_iterator *ptr);
488495
const vx_array *vx_array_iterator_next(vx_array_iterator *iter,
489496
vx_error **error_out);
490497

498+
/**
499+
* Clone a borrowed [`vx_binary`], returning an owned [`vx_binary`].
500+
*
501+
*
502+
* Must be released with [`vx_binary_free`].
503+
*/
504+
const vx_binary *vx_binary_clone(const vx_binary *ptr);
505+
506+
/**
507+
* Free an owned [`vx_binary`] object.
508+
*/
509+
void vx_binary_free(const vx_binary *ptr);
510+
511+
/**
512+
* Create a new Vortex binary value by copying `len` bytes starting at `ptr`.
513+
*/
514+
const vx_binary *vx_binary_new(const char *ptr, size_t len);
515+
516+
/**
517+
* Return the length of the binary value in bytes.
518+
*/
519+
size_t vx_binary_len(const vx_binary *ptr);
520+
521+
/**
522+
* Return a pointer to the binary data, which is not NUL-terminated; use `vx_binary_len` for its length.
523+
*/
524+
const char *vx_binary_ptr(const vx_binary *ptr);
525+
491526
/**
492527
* Clone a borrowed [`vx_dtype`], returning an owned [`vx_dtype`].
493528
*
@@ -629,9 +664,9 @@ bool vx_dtype_is_timestamp(const DType *dtype);
629664
uint8_t vx_dtype_time_unit(const DType *dtype);
630665

631666
/**
632-
* Returns the time zone, assuming the type is time.
667+
* Returns the time zone, assuming the type is time. Caller is responsible for freeing the returned pointer.
633668
*/
634-
void vx_dtype_time_zone(const DType *dtype, void *dst, int *len);
669+
const vx_string *vx_dtype_time_zone(const DType *dtype);
635670

636671
/**
637672
* Free an owned [`vx_error`] object.

vortex-ffi/src/array.rs

Lines changed: 41 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
//! FFI interface for working with Vortex Arrays.
5-
use std::ffi::{c_int, c_void};
6-
use std::slice;
5+
use std::ptr;
6+
use std::sync::Arc;
77

88
use vortex::dtype::half::f16;
9-
use vortex::error::{VortexExpect, VortexUnwrap, vortex_err};
9+
use vortex::error::{VortexExpect, vortex_err};
1010
use vortex::{Array, ToCanonical};
1111

1212
use crate::arc_dyn_wrapper;
13+
use crate::binary::vx_binary;
1314
use crate::dtype::vx_dtype;
1415
use crate::error::{try_or_default, vx_error};
16+
use crate::string::vx_string;
1517

1618
arc_dyn_wrapper!(
1719
/// Base type for all Vortex arrays.
@@ -133,48 +135,42 @@ ffiarray_get_ptype!(f16);
133135
ffiarray_get_ptype!(f32);
134136
ffiarray_get_ptype!(f64);
135137

136-
/// Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
137-
/// the length in `len`.
138+
/// Return the utf-8 string at `index` in the array. The pointer will be null if the value at `index` is null.
139+
/// The caller must free the returned pointer with `vx_string_free`.
138140
#[unsafe(no_mangle)]
139141
pub unsafe extern "C-unwind" fn vx_array_get_utf8(
140142
array: *const vx_array,
141143
index: u32,
142-
dst: *mut c_void,
143-
len: *mut c_int,
144-
) {
144+
) -> *const vx_string {
145145
let array = vx_array::as_ref(array);
146146
let value = array.scalar_at(index as usize);
147147
let utf8_scalar = value.as_utf8();
148148
if let Some(buffer) = utf8_scalar.value() {
149-
let bytes = buffer.as_bytes();
150-
let dst = unsafe { slice::from_raw_parts_mut(dst as *mut u8, bytes.len()) };
151-
dst.copy_from_slice(bytes);
152-
unsafe { *len = bytes.len().try_into().vortex_unwrap() };
149+
vx_string::new(Arc::from(buffer.as_str()))
150+
} else {
151+
ptr::null()
153152
}
154153
}
155154

156-
/// Write the UTF-8 string at `index` in the array into the provided destination buffer, recording
157-
/// the length in `len`.
155+
/// Return the binary at `index` in the array. The pointer will be null if the value at `index` is null.
156+
/// The caller must free the returned pointer with `vx_binary_free`.
158157
#[unsafe(no_mangle)]
159158
pub unsafe extern "C-unwind" fn vx_array_get_binary(
160159
array: *const vx_array,
161160
index: u32,
162-
dst: *mut c_void,
163-
len: *mut c_int,
164-
) {
161+
) -> *const vx_binary {
165162
let array = vx_array::as_ref(array);
166163
let value = array.scalar_at(index as usize);
167-
let utf8_scalar = value.as_binary();
168-
if let Some(bytes) = utf8_scalar.value() {
169-
let dst = unsafe { slice::from_raw_parts_mut(dst as *mut u8, bytes.len()) };
170-
dst.copy_from_slice(&bytes);
171-
unsafe { *len = bytes.len().try_into().vortex_unwrap() };
164+
let binary_scalar = value.as_binary();
165+
if let Some(bytes) = binary_scalar.value() {
166+
vx_binary::new(Arc::from(bytes.as_bytes()))
167+
} else {
168+
ptr::null()
172169
}
173170
}
174171

175172
#[cfg(test)]
176173
mod tests {
177-
use std::ffi::{c_int, c_void};
178174
use std::ptr;
179175

180176
use vortex::IntoArray;
@@ -185,8 +181,10 @@ mod tests {
185181
use vortex::validity::Validity;
186182

187183
use crate::array::*;
184+
use crate::binary::vx_binary_free;
188185
use crate::dtype::{vx_dtype_get_variant, vx_dtype_variant};
189186
use crate::error::vx_error_free;
187+
use crate::string::vx_string_free;
190188

191189
#[test]
192190
fn test_simple() {
@@ -349,35 +347,17 @@ mod tests {
349347
let utf8_array = VarBinViewArray::from_iter_str(["hello", "world", "test"]);
350348
let ffi_array = vx_array::new(utf8_array.into_array());
351349

352-
let mut buffer = vec![0u8; 10];
353-
let mut len: c_int = 0;
350+
let vx_str1 = vx_array_get_utf8(ffi_array, 0);
351+
assert_eq!(vx_string::as_str(vx_str1), "hello");
352+
vx_string_free(vx_str1);
354353

355-
vx_array_get_utf8(
356-
ffi_array,
357-
0,
358-
buffer.as_mut_ptr() as *mut c_void,
359-
&raw mut len,
360-
);
361-
assert_eq!(len, 5);
362-
assert_eq!(&buffer[..5], b"hello");
363-
364-
vx_array_get_utf8(
365-
ffi_array,
366-
1,
367-
buffer.as_mut_ptr() as *mut c_void,
368-
&raw mut len,
369-
);
370-
assert_eq!(len, 5);
371-
assert_eq!(&buffer[..5], b"world");
372-
373-
vx_array_get_utf8(
374-
ffi_array,
375-
2,
376-
buffer.as_mut_ptr() as *mut c_void,
377-
&raw mut len,
378-
);
379-
assert_eq!(len, 4);
380-
assert_eq!(&buffer[..4], b"test");
354+
let vx_str2 = vx_array_get_utf8(ffi_array, 1);
355+
assert_eq!(vx_string::as_str(vx_str2), "world");
356+
vx_string_free(vx_str2);
357+
358+
let vx_str3 = vx_array_get_utf8(ffi_array, 2);
359+
assert_eq!(vx_string::as_str(vx_str3), "test");
360+
vx_string_free(vx_str3);
381361

382362
vx_array_free(ffi_array);
383363
}
@@ -393,35 +373,17 @@ mod tests {
393373
]);
394374
let ffi_array = vx_array::new(binary_array.into_array());
395375

396-
let mut buffer = vec![0u8; 10];
397-
let mut len: c_int = 0;
376+
let vx_bin1 = vx_array_get_binary(ffi_array, 0);
377+
assert_eq!(vx_binary::as_slice(vx_bin1), &[0x01, 0x02, 0x03]);
378+
vx_binary_free(vx_bin1);
398379

399-
vx_array_get_binary(
400-
ffi_array,
401-
0,
402-
buffer.as_mut_ptr() as *mut c_void,
403-
&raw mut len,
404-
);
405-
assert_eq!(len, 3);
406-
assert_eq!(&buffer[..3], &[0x01, 0x02, 0x03]);
407-
408-
vx_array_get_binary(
409-
ffi_array,
410-
1,
411-
buffer.as_mut_ptr() as *mut c_void,
412-
&raw mut len,
413-
);
414-
assert_eq!(len, 2);
415-
assert_eq!(&buffer[..2], &[0xFF, 0xEE]);
416-
417-
vx_array_get_binary(
418-
ffi_array,
419-
2,
420-
buffer.as_mut_ptr() as *mut c_void,
421-
&raw mut len,
422-
);
423-
assert_eq!(len, 4);
424-
assert_eq!(&buffer[..4], &[0xAA, 0xBB, 0xCC, 0xDD]);
380+
let vx_bin2 = vx_array_get_binary(ffi_array, 1);
381+
assert_eq!(vx_binary::as_slice(vx_bin2), &[0xFF, 0xEE]);
382+
vx_binary_free(vx_bin2);
383+
384+
let vx_bin3 = vx_array_get_binary(ffi_array, 2);
385+
assert_eq!(vx_binary::as_slice(vx_bin3), &[0xAA, 0xBB, 0xCC, 0xDD]);
386+
vx_binary_free(vx_bin3);
425387

426388
vx_array_free(ffi_array);
427389
}

vortex-ffi/src/binary.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ffi::c_char;
5+
use std::slice;
6+
7+
use crate::arc_dyn_wrapper;
8+
9+
arc_dyn_wrapper!(
10+
/// Binary (arbitrary byte string) values for use within Vortex.
11+
[u8],
12+
vx_binary
13+
);
14+
15+
impl vx_binary {
16+
#[allow(dead_code)]
17+
pub(crate) fn as_slice(ptr: *const vx_binary) -> &'static [u8] {
18+
unsafe { slice::from_raw_parts(vx_binary_ptr(ptr).cast(), vx_binary_len(ptr)) }
19+
}
20+
}
21+
22+
/// Create a new Vortex binary value by copying `len` bytes starting at `ptr`.
23+
#[unsafe(no_mangle)]
24+
pub unsafe extern "C-unwind" fn vx_binary_new(ptr: *const c_char, len: usize) -> *const vx_binary {
25+
let slice = unsafe { slice::from_raw_parts(ptr.cast(), len) };
26+
vx_binary::new(slice.into())
27+
}
28+
29+
/// Return the length of the binary value in bytes.
30+
#[unsafe(no_mangle)]
31+
pub unsafe extern "C-unwind" fn vx_binary_len(ptr: *const vx_binary) -> usize {
32+
vx_binary::as_ref(ptr).len()
33+
}
34+
35+
/// Return a pointer to the binary data, which is not NUL-terminated; use `vx_binary_len` for its length.
36+
#[unsafe(no_mangle)]
37+
pub unsafe extern "C-unwind" fn vx_binary_ptr(ptr: *const vx_binary) -> *const c_char {
38+
vx_binary::as_ref(ptr).as_ptr().cast()
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[test]
46+
fn test_string_new() {
47+
unsafe {
48+
let test_str = "hello world";
49+
let ptr = test_str.as_ptr() as *const c_char;
50+
let len = test_str.len();
51+
52+
let vx_str = vx_binary_new(ptr, len);
53+
assert_eq!(vx_binary_len(vx_str), 11);
54+
assert_eq!(vx_binary::as_slice(vx_str), "hello world".as_bytes());
55+
56+
vx_binary_free(vx_str);
57+
}
58+
}
59+
60+
#[test]
61+
fn test_string_ptr() {
62+
unsafe {
63+
let test_str = "testing".as_bytes();
64+
let vx_str = vx_binary::new(test_str.into());
65+
66+
let ptr = vx_binary_ptr(vx_str);
67+
let len = vx_binary_len(vx_str);
68+
69+
let slice = slice::from_raw_parts(ptr as *const u8, len);
70+
assert_eq!(slice, "testing".as_bytes());
71+
72+
vx_binary_free(vx_str);
73+
}
74+
}
75+
76+
#[test]
77+
fn test_empty_string() {
78+
unsafe {
79+
let empty = "";
80+
let ptr = empty.as_ptr() as *const c_char;
81+
let vx_str = vx_binary_new(ptr, 0);
82+
83+
assert_eq!(vx_binary_len(vx_str), 0);
84+
assert_eq!(vx_binary::as_slice(vx_str), "".as_bytes());
85+
86+
vx_binary_free(vx_str);
87+
}
88+
}
89+
90+
#[test]
91+
fn test_unicode_string() {
92+
unsafe {
93+
let unicode_str = "Hello 世界 🌍";
94+
let ptr = unicode_str.as_ptr() as *const c_char;
95+
let len = unicode_str.len();
96+
97+
let vx_str = vx_binary_new(ptr, len);
98+
assert_eq!(vx_binary_len(vx_str), unicode_str.len());
99+
assert_eq!(vx_binary::as_slice(vx_str), unicode_str.as_bytes());
100+
101+
vx_binary_free(vx_str);
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)