Skip to content

Commit 8e09d44

Browse files
committed
wip arbitrary arrays
1 parent a05fa95 commit 8e09d44

File tree

3 files changed

+198
-102
lines changed

3 files changed

+198
-102
lines changed

matlab/rust/wkw_load/src/lib.rs

Lines changed: 33 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ extern crate zarrs;
44
#[macro_use]
55
extern crate wkw_mex;
66
use wkw_mex::*;
7-
use wkwrap::{Box3, Mat, Vec3, VoxelType};
87
use zarrs::array::data_type::DataType;
98
use zarrs::array::Array;
109
use zarrs::array_subset::ArraySubset;
@@ -40,55 +39,37 @@ mex_function!(nlhs, lhs, nrhs, rhs, {
4039
);
4140
let array = zarrs_result_to_str_error(Array::open(store.clone(), "/"))?;
4241

43-
let num_channels = array.shape()[0] as usize;
44-
let is_multi_channel = num_channels > 1;
42+
let array_shape = array.shape();
43+
let ndim = array_shape.len();
44+
let data_type = array.data_type();
45+
let type_size = if let Some(type_size) = data_type.fixed_size() {
46+
type_size
47+
} else {
48+
return Err("Unsupported data type".to_string());
49+
};
4550

4651
// build shape
47-
let bbox = mx_array_to_wkwrap_box(rhs[1])?;
52+
let (bbox_start, bbox_shape) = mx_array_to_bbox(rhs[1], ndim)?;
53+
54+
if bbox_start
55+
.iter()
56+
.zip(bbox_shape.iter())
57+
.zip(array_shape.iter())
58+
.any(|((bbox_min_x, bbox_shape_x), shape_x)| (*bbox_min_x + *bbox_shape_x) > *shape_x)
59+
{
60+
return Err(format!(
61+
"Bounding box start={:?}, shape={:?} is out of bounds for array of shape={:?}.",
62+
bbox_start, bbox_shape, array_shape
63+
));
64+
}
65+
println!("{:?} {:?}", bbox_start, bbox_shape);
4866
let subset = zarrs_result_to_str_error(ArraySubset::new_with_start_shape(
49-
vec![
50-
0,
51-
bbox.min().x as u64,
52-
bbox.min().y as u64,
53-
bbox.min().z as u64,
54-
],
55-
vec![
56-
1,
57-
bbox.width().x as u64,
58-
bbox.width().y as u64,
59-
bbox.width().z as u64,
60-
],
67+
bbox_start.clone(),
68+
bbox_shape.clone(),
6169
))?;
6270

63-
let shape_arr = [
64-
num_channels,
65-
bbox.width().x as usize,
66-
bbox.width().y as usize,
67-
bbox.width().z as usize,
68-
];
69-
let shape_slice = if is_multi_channel {
70-
&shape_arr[0..]
71-
} else {
72-
&shape_arr[1..]
73-
};
74-
7571
// prepare allocation
76-
let voxel_type = match array.data_type() {
77-
DataType::UInt8 => VoxelType::U8,
78-
DataType::UInt16 => VoxelType::U16,
79-
DataType::UInt32 => VoxelType::U32,
80-
DataType::UInt64 => VoxelType::U64,
81-
DataType::Float32 => VoxelType::F32,
82-
DataType::Float64 => VoxelType::F64,
83-
DataType::Int8 => VoxelType::I8,
84-
DataType::Int16 => VoxelType::I16,
85-
DataType::Int32 => VoxelType::I32,
86-
DataType::Int64 => VoxelType::I64,
87-
_ => {
88-
return Err("Unsupported data type".to_string());
89-
}
90-
};
91-
let class = match array.data_type() {
72+
let mat_class = match array.data_type() {
9273
DataType::UInt8 => MxClassId::Uint8,
9374
DataType::UInt16 => MxClassId::Uint16,
9475
DataType::UInt32 => MxClassId::Uint32,
@@ -106,24 +87,16 @@ mex_function!(nlhs, lhs, nrhs, rhs, {
10687

10788
// read data
10889
let data_all = zarrs_result_to_str_error(array.retrieve_array_subset(&subset))?;
109-
let mut buf = zarrs_result_to_str_error(data_all.into_fixed())?.into_owned(); // in c-order
110-
let src_mat = Mat::new(
111-
&mut buf,
112-
bbox.width(),
113-
voxel_type.size() * num_channels,
114-
voxel_type,
115-
true,
116-
)?;
117-
let arr = create_numeric_array(shape_slice, class, MxComplexity::Real)?;
118-
let mut mat = mx_array_mut_to_wkwrap_mat(is_multi_channel, arr)?;
119-
120-
src_mat.copy_as_fortran_order(
121-
&mut mat,
122-
Box3::new(Vec3 { x: 0, y: 0, z: 0 }, bbox.width())?,
123-
)?;
90+
let zarr_buf = zarrs_result_to_str_error(data_all.into_fixed())?.into_owned(); // in c-order
91+
92+
println!("zarr_buf: {:?}", zarr_buf);
93+
94+
let mat_arr = create_numeric_array(&bbox_shape, mat_class, MxComplexity::Real)?;
95+
96+
copy_as_fortran_order(&zarr_buf, mat_arr, &bbox_shape, type_size)?;
12497

12598
// set output
126-
lhs[0] = arr;
99+
lhs[0] = mat_arr;
127100

128101
Ok(())
129102
});

matlab/rust/wkw_mex/src/util.rs

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
use ::ffi::*;
1+
use ffi::*;
22

33
use std;
4-
use std::slice;
54
use std::ffi::CStr;
5+
use std::slice;
66

77
pub type Result<T> = std::result::Result<T, String>;
88

99
pub fn as_nat(f: f64) -> Result<u64> {
1010
if f <= 0.0 {
11-
return Err("Input must be positive".to_string())
11+
return Err("Input must be positive".to_string());
1212
}
1313

1414
match f % 1.0 == 0.0 {
1515
true => Ok(f as u64),
16-
false => Err("Input must be an integer".to_string())
16+
false => Err("Input must be an integer".to_string()),
1717
}
1818
}
1919

@@ -22,53 +22,57 @@ pub fn as_log2(f: f64) -> Result<u8> {
2222

2323
match i & (i - 1) == 0 {
2424
true => Ok(i.trailing_zeros() as u8),
25-
false => Err("Input must be a power of two".to_string())
25+
false => Err("Input must be a power of two".to_string()),
2626
}
2727
}
2828

2929
pub fn str_slice_to_mx_class_id(class_id: &str) -> Result<MxClassId> {
3030
match class_id {
31-
"uint8" => Ok(MxClassId::Uint8),
31+
"uint8" => Ok(MxClassId::Uint8),
3232
"uint16" => Ok(MxClassId::Uint16),
3333
"uint32" => Ok(MxClassId::Uint32),
3434
"uint64" => Ok(MxClassId::Uint64),
3535
"single" => Ok(MxClassId::Single),
3636
"double" => Ok(MxClassId::Double),
37-
"int8" => Ok(MxClassId::Int8),
38-
"int16" => Ok(MxClassId::Int16),
39-
"int32" => Ok(MxClassId::Int32),
40-
"int64" => Ok(MxClassId::Int64),
41-
_ => Err("Unknown MxClassId name".to_string())
37+
"int8" => Ok(MxClassId::Int8),
38+
"int16" => Ok(MxClassId::Int16),
39+
"int32" => Ok(MxClassId::Int32),
40+
"int64" => Ok(MxClassId::Int64),
41+
_ => Err("Unknown MxClassId name".to_string()),
4242
}
4343
}
4444

4545
pub fn mx_array_to_str<'a>(pm: MxArray) -> Result<&'a str> {
4646
let pm_ptr = unsafe { mxArrayToUTF8String(pm) };
4747

4848
if pm_ptr.is_null() {
49-
return Err("mxArrayToUTF8String returned null".to_string())
49+
return Err("mxArrayToUTF8String returned null".to_string());
5050
}
5151

5252
let pm_cstr = unsafe { CStr::from_ptr(pm_ptr) };
5353

5454
match pm_cstr.to_str() {
5555
Ok(pm_str) => Ok(pm_str),
56-
Err(_) => Err("mxArray contains invalid UTF-8 data".to_string())
56+
Err(_) => Err("mxArray contains invalid UTF-8 data".to_string()),
5757
}
5858
}
5959

6060
pub fn mx_array_to_f64_slice<'a>(pm: MxArray) -> Result<&'a [f64]> {
6161
unsafe {
62-
if !mxIsDouble(pm) { return Err("MxArray is not of class \"double\"".to_string()) };
63-
if mxIsComplex(pm) { return Err("MxArray is complex".to_string()) };
62+
if !mxIsDouble(pm) {
63+
return Err("MxArray is not of class \"double\"".to_string());
64+
};
65+
if mxIsComplex(pm) {
66+
return Err("MxArray is complex".to_string());
67+
};
6468
}
6569

6670
let pm_numel = unsafe { mxGetNumberOfElements(pm) };
6771
let pm_ptr = unsafe { mxGetPr(pm) };
6872

6973
match pm_ptr.is_null() {
7074
true => Err("MxArray does not contain real values".to_string()),
71-
false => Ok(unsafe { slice::from_raw_parts(pm_ptr, pm_numel) })
75+
false => Ok(unsafe { slice::from_raw_parts(pm_ptr, pm_numel) }),
7276
}
7377
}
7478

@@ -77,7 +81,7 @@ pub fn mx_array_to_f64(pm: MxArray) -> Result<f64> {
7781

7882
match pm_slice.len() {
7983
1 => Ok(pm_slice[0]),
80-
_ => Err("MxArray contains an invalid number of doubles".to_string())
84+
_ => Err("MxArray contains an invalid number of doubles".to_string()),
8185
}
8286
}
8387

@@ -113,25 +117,26 @@ pub fn mx_array_size_to_usize_slice<'a>(pm: MxArray) -> &'a [usize] {
113117
let ndims = unsafe { mxGetNumberOfDimensions(pm) };
114118
let dims = unsafe { mxGetDimensions(pm) };
115119

116-
unsafe {
117-
slice::from_raw_parts(dims, ndims as usize)
118-
}
120+
unsafe { slice::from_raw_parts(dims, ndims as usize) }
119121
}
120122

121123
pub fn create_numeric_array(
122-
dims: &[usize],
124+
dims: &[u64],
123125
class: MxClassId,
124-
complexity: MxComplexity
126+
complexity: MxComplexity,
125127
) -> Result<MxArrayMut> {
126128
let arr = unsafe {
127129
mxCreateNumericArray(
128-
dims.len() as size_t, dims.as_ptr(),
129-
class as c_int, complexity as c_int)
130+
dims.len() as size_t,
131+
dims.as_ptr() as *const usize,
132+
class as c_int,
133+
complexity as c_int,
134+
)
130135
};
131136

132137
match arr.is_null() {
133138
true => Err("Failed to create uninitialized numeric array".to_string()),
134-
false => Ok(arr)
139+
false => Ok(arr),
135140
}
136141
}
137142

@@ -140,7 +145,7 @@ pub fn malloc(n: usize) -> Result<&'static mut [u8]> {
140145

141146
match ptr.is_null() {
142147
true => Err("Failed to allocate memory".to_string()),
143-
false => Ok(unsafe { slice::from_raw_parts_mut(ptr, n) })
148+
false => Ok(unsafe { slice::from_raw_parts_mut(ptr, n) }),
144149
}
145150
}
146151

@@ -156,3 +161,63 @@ pub fn die(msg: &str) {
156161
// die
157162
unsafe { mexErrMsgTxt(buf.as_ptr()) }
158163
}
164+
165+
pub fn copy_as_fortran_order(
166+
in_buf: &[u8],
167+
out_arr: MxArrayMut,
168+
shape: &[u64],
169+
type_size: usize,
170+
) -> Result<()> {
171+
let total_elems: usize = shape.iter().product::<u64>() as usize;
172+
if in_buf.len() != total_elems * type_size {
173+
return Err(format!(
174+
"Length of input buffer does not match expected size {} != {}",
175+
in_buf.len(),
176+
total_elems,
177+
));
178+
}
179+
180+
let result = mx_array_mut_to_u8_slice_mut(out_arr)?;
181+
if result.len() != total_elems * type_size {
182+
return Err(format!(
183+
"Length of output array does not match expected size {} != {}",
184+
result.len(),
185+
total_elems,
186+
));
187+
}
188+
189+
// Compute F-order (column-major) strides
190+
let mut f_strides = vec![1u64; shape.len()];
191+
for i in 1..shape.len() {
192+
f_strides[i] = f_strides[i - 1] * shape[i - 1];
193+
}
194+
195+
// Multi-dimensional index for C-order iteration
196+
let mut idx = vec![0u64; shape.len()];
197+
198+
// Iterate over all elements in C-order (sequential read)
199+
for elem_idx in 0..total_elems {
200+
// Compute Fortran-order (column-major) offset
201+
let f_offset_elems: u64 = idx.iter().zip(&f_strides).map(|(&i, &s)| i * s).sum();
202+
let f_offset_bytes = f_offset_elems as usize * type_size;
203+
204+
// Sequential read from in_buf
205+
let src_offset_bytes = elem_idx * type_size;
206+
let src_slice = &in_buf[src_offset_bytes..src_offset_bytes + type_size];
207+
208+
// Scattered write to result
209+
result[f_offset_bytes..f_offset_bytes + type_size].copy_from_slice(src_slice);
210+
211+
// Increment multi-dimensional index (C-order)
212+
for d in (0..shape.len()).rev() {
213+
idx[d] += 1;
214+
if idx[d] < shape[d] {
215+
break;
216+
} else if d > 0 {
217+
idx[d] = 0;
218+
}
219+
}
220+
}
221+
222+
Ok(())
223+
}

0 commit comments

Comments
 (0)