Skip to content

Commit 4ec0f68

Browse files
a10y and robert3005 authored
fix: handling of struct fields in Spark (#5453)
Fixes some issues I encountered trying to get some files with nested structs/lists to load. I'm still hitting a very strange issue deep in C Arrow Data interface with the 3.5GB github archives dataset, but this fixes some other issues. --------- Signed-off-by: Andrew Duffy <[email protected]> Signed-off-by: Robert Kruszewski <[email protected]> Co-authored-by: Robert Kruszewski <[email protected]>
1 parent 52ee607 commit 4ec0f68

File tree

8 files changed

+100
-24
lines changed

8 files changed

+100
-24
lines changed

java/settings.gradle.kts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@ rootProject.name = "vortex-root"
2020
// API bindings
2121
include("vortex-jni")
2222
include("vortex-spark")
23-

java/vortex-spark/build.gradle.kts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ tasks.withType<ShadowJar> {
9191
}
9292
}
9393

94-
9594
tasks.withType<Test>().all {
9695
classpath +=
9796
project(":vortex-jni")

java/vortex-spark/src/main/java/dev/vortex/spark/SparkTypes.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
package dev.vortex.spark;
55

6-
import com.google.common.collect.Streams;
76
import dev.vortex.api.DType;
87
import java.util.Optional;
98
import org.apache.spark.sql.connector.catalog.Column;
@@ -107,13 +106,20 @@ public static DataType toDataType(DType dType) {
107106
return DataTypes.BinaryType;
108107
case STRUCT:
109108
// For each of the inner struct fields, we capture them together here.
110-
var struct = new StructType();
109+
var fieldNames = dType.getFieldNames();
110+
var fieldTypes = dType.getFieldTypes();
111111

112-
Streams.forEachPair(
113-
dType.getFieldNames().stream(),
114-
dType.getFieldTypes().stream(),
115-
(name, type) -> struct.add(name, toDataType(type)));
116-
return struct;
112+
// NOTE: it's very important we do this with a for loop. Using the streams API can easily
113+
// lead to StackOverflowError being thrown.
114+
var fields = new StructField[fieldNames.size()];
115+
for (int i = 0; i < fieldNames.size(); i++) {
116+
var name = fieldNames.get(i);
117+
try (var type = fieldTypes.get(i)) {
118+
fields[i] = new StructField(name, toDataType(type), dType.isNullable(), Metadata.empty());
119+
}
120+
}
121+
122+
return DataTypes.createStructType(fields);
117123
case LIST:
118124
return DataTypes.createArrayType(toDataType(dType.getElementType()), dType.isNullable());
119125
case EXTENSION:
@@ -151,10 +157,17 @@ public static DataType toDataType(DType dType) {
151157
* Convert a STRUCT Vortex type to a Spark {@link Column}.
152158
*/
153159
public static Column[] toColumns(DType dType) {
154-
return Streams.zip(dType.getFieldNames().stream(), dType.getFieldTypes().stream(), (name, fieldType) -> {
155-
var dataType = toDataType(fieldType);
156-
return Column.create(name, dataType, fieldType.isNullable());
157-
})
158-
.toArray(Column[]::new);
160+
var fieldNames = dType.getFieldNames();
161+
var fieldTypes = dType.getFieldTypes();
162+
var columns = new Column[fieldNames.size()];
163+
164+
for (int i = 0; i < columns.length; i++) {
165+
var name = fieldNames.get(i);
166+
try (var type = fieldTypes.get(i)) {
167+
columns[i] = Column.create(name, toDataType(type), type.isNullable());
168+
}
169+
}
170+
171+
return columns;
159172
}
160173
}

java/vortex-spark/src/test/java/dev/vortex/spark/VortexDataSourceBasicTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import static org.junit.jupiter.api.Assertions.*;
77

8+
import dev.vortex.relocated.org.apache.arrow.vector.types.pojo.ArrowType;
89
import org.apache.spark.sql.types.DataTypes;
910
import org.apache.spark.sql.types.StructField;
1011
import org.apache.spark.sql.types.StructType;
@@ -48,6 +49,48 @@ public void testSparkToArrowSchemaConversion() {
4849
assertEquals("active", arrowSchema.getFields().get(3).getName());
4950
}
5051

52+
@Test
53+
@DisplayName("SparkToArrowSchema should convert nested types")
54+
public void testNestedSparkToArrowSchemaConversion() {
55+
// Create a more complex spark schema
56+
StructType sparkSchema = DataTypes.createStructType(new StructField[] {
57+
DataTypes.createStructField(
58+
"inner",
59+
DataTypes.createStructType(new StructField[] {
60+
DataTypes.createStructField("id", DataTypes.IntegerType, false),
61+
DataTypes.createStructField("name", DataTypes.StringType, true),
62+
DataTypes.createStructField("value", DataTypes.DoubleType, false),
63+
DataTypes.createStructField("active", DataTypes.BooleanType, true)
64+
}),
65+
false)
66+
});
67+
68+
// Convert to Arrow schema
69+
var arrowSchema = dev.vortex.spark.write.SparkToArrowSchema.convert(sparkSchema);
70+
71+
// Verify conversion
72+
assertNotNull(arrowSchema, "Arrow schema should not be null");
73+
assertEquals(1, arrowSchema.getFields().size(), "Arrow schema should have same number of fields");
74+
75+
// Should contain the right inner fields
76+
var nestedFields = arrowSchema.getFields().get(0).getChildren();
77+
78+
// Verify field types are preserved
79+
assertInstanceOf(ArrowType.Struct.class, arrowSchema.getFields().get(0).getType());
80+
81+
assertEquals("id", nestedFields.get(0).getName());
82+
assertInstanceOf(ArrowType.Int.class, nestedFields.get(0).getType());
83+
84+
assertEquals("name", nestedFields.get(1).getName());
85+
assertInstanceOf(ArrowType.Utf8.class, nestedFields.get(1).getType());
86+
87+
assertEquals("value", nestedFields.get(2).getName());
88+
assertInstanceOf(ArrowType.FloatingPoint.class, nestedFields.get(2).getType());
89+
90+
assertEquals("active", nestedFields.get(3).getName());
91+
assertInstanceOf(ArrowType.Bool.class, nestedFields.get(3).getType());
92+
}
93+
5194
@Test
5295
@DisplayName("VortexWriterCommitMessage should store metadata correctly")
5396
public void testWriterCommitMessage() {

vortex-dtype/src/field_names.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,15 +322,27 @@ impl From<&[&'static str]> for FieldNames {
322322
}
323323
}
324324

325+
impl<const N: usize> From<[&str; N]> for FieldNames {
326+
fn from(value: [&str; N]) -> Self {
327+
Self(value.iter().cloned().map(FieldName::from).collect())
328+
}
329+
}
330+
331+
impl From<Vec<&str>> for FieldNames {
332+
fn from(value: Vec<&str>) -> Self {
333+
Self(value.into_iter().map(FieldName::from).collect())
334+
}
335+
}
336+
325337
impl From<&[FieldName]> for FieldNames {
326338
fn from(value: &[FieldName]) -> Self {
327339
Self(Arc::from(value))
328340
}
329341
}
330342

331-
impl<const N: usize> From<[&'static str; N]> for FieldNames {
332-
fn from(value: [&'static str; N]) -> Self {
333-
Self(value.into_iter().map(FieldName::from).collect())
343+
impl<const N: usize> From<&[&str; N]> for FieldNames {
344+
fn from(value: &[&str; N]) -> Self {
345+
Self(value.iter().cloned().map(FieldName::from).collect())
334346
}
335347
}
336348

vortex-jni/src/array.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ fn data_type_no_views(data_type: DataType) -> DataType {
120120
DataType::LargeList(FieldRef::new(new_inner))
121121
}
122122
DataType::Struct(fields) => {
123-
// Things
124123
let viewless_fields: Vec<FieldRef> = fields
125124
.iter()
126125
.map(|field_ref| {

vortex-jni/src/dtype.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getFieldTypes(
128128
let dtype = unsafe { &*(dtype_ptr as *const DType) };
129129

130130
try_or_throw(&mut env, |env| {
131-
let array_list = env.new_object("java/util/ArrayList", "()V", &[])?;
131+
let array_list = env
132+
.new_object("java/util/ArrayList", "()V", &[])
133+
.map_err(|e| JNIError::Vortex(vortex_err!("failure constructing ArrayList: {e}")))?;
132134
let field_types = env.get_list(&array_list)?;
133135
let Some(struct_dtype) = dtype.as_struct_fields_opt() else {
134136
throw_runtime!("DType should be STRUCT, was {dtype}");

vortex-jni/src/errors.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl JNIDefault for jobject {
7676
}
7777
}
7878

79-
/// Run the provided function inside of the JNIEnv context. Throws an exception if the function returns an error.
79+
/// Run the provided function inside the JNIEnv context. Throws an exception if the function returns an error.
8080
#[allow(clippy::expect_used)]
8181
#[inline]
8282
pub fn try_or_throw<'a, F, T>(env: &mut JNIEnv<'a>, function: F) -> T
@@ -87,12 +87,21 @@ where
8787
match function(env) {
8888
Ok(result) => result,
8989
Err(error) => {
90+
// Propagate the exception instead of throwing our own.
91+
if env
92+
.exception_check()
93+
.expect("checking exception should succeed")
94+
{
95+
return T::jni_default();
96+
}
97+
9098
let msg = error.to_string();
91-
env.throw((RUNTIME_EXC_CLASS, msg))
92-
.expect("throwing exception back to Java failed, everything is bad");
99+
match env.throw(msg) {
100+
Ok(()) => {}
101+
Err(err) => log::warn!("Failed throwing exception back up to Java: {err}"),
102+
}
103+
93104
T::jni_default()
94105
}
95106
}
96107
}
97-
98-
pub static RUNTIME_EXC_CLASS: &str = "java/lang/RuntimeException";

0 commit comments

Comments (0)