Commit d8e24a9

fix(java): Spark writer properly handles string and array data (#4335)

We need to use setSafe(), as the initial allocation might not be large enough to hold all of the string data in the InternalRow, and a realloc might be required. Also adds support for Spark `ArrayType`, and augments the existing roundtrip writer unit test to include an `array<string>` column.

---------

Signed-off-by: Andrew Duffy <[email protected]>

1 parent 642d4d8 · commit d8e24a9
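
For context on the set() vs. setSafe() distinction the commit message describes: Arrow's variable-width vectors pre-allocate a fixed-size data buffer, set() writes into it without checking capacity, and setSafe() grows the buffer first when the incoming bytes do not fit. A minimal standalone sketch of the difference (vanilla Arrow imports rather than Vortex's relocated packages; vector name and sizes are illustrative):

import java.nio.charset.StandardCharsets;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VarCharVector;

public final class SetSafeSketch {
    public static void main(String[] args) {
        try (BufferAllocator allocator = new RootAllocator();
                VarCharVector vector = new VarCharVector("value", allocator)) {
            // Budget only 8 bytes of string data across 4 slots.
            vector.allocateNew(8, 4);

            byte[] big = "a string longer than eight bytes".getBytes(StandardCharsets.UTF_8);

            // set() assumes the data buffer is already large enough; once the
            // 8-byte budget is exhausted it fails rather than growing the buffer:
            // vector.set(0, big);

            // setSafe() checks capacity first and reallocates as needed.
            vector.setSafe(0, big);
            vector.setValueCount(1);

            System.out.println(vector.getObject(0));
        }
    }
}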

File tree

3 files changed: +67 -14 lines changed


java/testfiles/Cargo.lock

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default.

java/vortex-spark/src/main/java/dev/vortex/spark/write/VortexDataWriter.java

Lines changed: 18 additions & 4 deletions
@@ -8,6 +8,7 @@
 import dev.vortex.relocated.org.apache.arrow.memory.RootAllocator;
 import dev.vortex.relocated.org.apache.arrow.vector.*;
 import dev.vortex.relocated.org.apache.arrow.vector.VectorSchemaRoot;
+import dev.vortex.relocated.org.apache.arrow.vector.complex.ListVector;
 import dev.vortex.relocated.org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import dev.vortex.spark.SparkTypes;
 import java.io.ByteArrayOutputStream;
@@ -20,6 +21,8 @@
 import java.util.List;
 import java.util.Map;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
+import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.connector.write.DataWriter;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;
 import org.apache.spark.sql.types.*;
@@ -185,7 +188,8 @@ private void writeBatch() throws IOException {
     /**
      * Populates an Arrow vector with a value from an InternalRow.
      */
-    private void populateVector(FieldVector vector, DataType dataType, InternalRow row, int fieldIndex, int rowIndex) {
+    private void populateVector(
+            FieldVector vector, DataType dataType, SpecializedGetters row, int fieldIndex, int rowIndex) {
         if (dataType instanceof BooleanType) {
             ((BitVector) vector).set(rowIndex, row.getBoolean(fieldIndex) ? 1 : 0);
         } else if (dataType instanceof ByteType) {
@@ -203,12 +207,12 @@ private void populateVector(FieldVector vector, DataType dataType, InternalRow r
         } else if (dataType instanceof StringType) {
             UTF8String str = row.getUTF8String(fieldIndex);
             if (str != null) {
-                ((VarCharVector) vector).set(rowIndex, str.getBytes());
+                ((VarCharVector) vector).setSafe(rowIndex, str.getBytes());
             }
         } else if (dataType instanceof BinaryType) {
             byte[] bytes = row.getBinary(fieldIndex);
             if (bytes != null) {
-                ((VarBinaryVector) vector).set(rowIndex, bytes);
+                ((VarBinaryVector) vector).setSafe(rowIndex, bytes);
             }
         } else if (dataType instanceof DecimalType) {
             DecimalType decType = (DecimalType) dataType;
@@ -218,9 +222,19 @@ private void populateVector(FieldVector vector, DataType dataType, InternalRow r
                     .toJavaBigDecimal();
             ((DecimalVector) vector).set(rowIndex, decimal);
             }
+        } else if (dataType instanceof ArrayType) {
+            ArrayType arrayType = (ArrayType) dataType;
+            ArrayData data = row.getArray(fieldIndex);
+            ListVector listVector = ((ListVector) vector);
+            int writtenElements = listVector.getElementEndIndex(listVector.getLastSet());
+            listVector.startNewValue(rowIndex);
+            for (int i = 0; i < data.numElements(); i++) {
+                populateVector(listVector.getDataVector(), arrayType.elementType(), data, i, writtenElements + i);
+            }
+            listVector.endValue(rowIndex, data.numElements());
         } else {
             // For unsupported types, set null
-            vector.setNull(rowIndex);
+            throw new IllegalArgumentException("Unsupported data type: " + dataType);
         }
     }
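
The new `ArrayType` branch above leans on ListVector's offset bookkeeping: startNewValue(rowIndex) opens a list at the row, elements are written into the child data vector at a running offset, and endValue(rowIndex, n) records the list length. A standalone sketch of the same pattern (vanilla Arrow imports; the names and data are illustrative):

import java.nio.charset.StandardCharsets;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.FieldType;

public final class ListVectorSketch {
    public static void main(String[] args) {
        List<List<String>> rows = List.of(List.of("Alpha", "Bravo"), List.of("Charlie"));
        try (BufferAllocator allocator = new RootAllocator();
                ListVector list = ListVector.empty("elements", allocator)) {
            // Attach a utf8 child vector, then allocate initial buffers.
            list.addOrGetVector(FieldType.nullable(new ArrowType.Utf8()));
            list.allocateNew();
            VarCharVector child = (VarCharVector) list.getDataVector();

            int childIndex = 0; // running offset into the child vector
            for (int row = 0; row < rows.size(); row++) {
                list.startNewValue(row); // open the list at this row
                for (String s : rows.get(row)) {
                    child.setSafe(childIndex++, s.getBytes(StandardCharsets.UTF_8));
                }
                list.endValue(row, rows.get(row).size()); // record the list length
            }
            list.setValueCount(rows.size()); // also sizes the child via offsets

            System.out.println(list.getObject(0)); // [Alpha, Bravo]
            System.out.println(list.getObject(1)); // [Charlie]
        }
    }
}

Note the running child offset: each row's elements are appended after everything already written, which is what the writer recovers with getElementEndIndex(getLastSet()) when it resumes mid-batch.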

java/vortex-spark/src/test/java/dev/vortex/spark/VortexDataSourceWriteTest.java

Lines changed: 6 additions & 10 deletions
@@ -15,12 +15,10 @@
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.spark.sql.*;
-import org.apache.spark.sql.api.java.UDF1;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.junit.jupiter.api.*;
-import org.junit.jupiter.api.Assumptions;
 import org.junit.jupiter.api.io.TempDir;
 
 /**
@@ -68,7 +66,7 @@ public void testWriteAndReadVortexFiles() throws IOException {
 
         // Verify original data
         assertEquals(numRows, originalDf.count(), "Original DataFrame should have " + numRows + " rows");
-        assertEquals(2, originalDf.columns().length, "Original DataFrame should have 2 columns");
+        assertEquals(3, originalDf.columns().length, "Original DataFrame should have 3 columns");
 
         // When: Repartition to 2 partitions and write as Vortex
         Path outputPath = tempDir.resolve("vortex_output");
@@ -284,14 +282,12 @@ public void testSpecialCharactersAndNulls() throws IOException {
      * and their string representations.
      */
     private Dataset<Row> createTestDataFrame(int numRows) {
-        // Register UDF for integer to string conversion
-        spark.udf().register("intToString", (UDF1<Integer, String>) value -> "value_" + value, DataTypes.StringType);
-
         // Create DataFrame with monotonically increasing integers
-        Dataset<Row> df = spark.range(0, numRows)
-                .selectExpr("cast(id as int) as id", "concat('value_', cast(id as string)) as value");
-
-        return df;
+        return spark.range(0, numRows)
+                .selectExpr(
+                        "cast(id as int) as id",
+                        "concat('value_', cast(id as string)) as value",
+                        "array('Alpha', 'Bravo', 'Charlie') AS elements");
     }
 
     /**
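
The updated createTestDataFrame now produces three columns (id, value, array<string> elements), so the roundtrip test exercises the new array path end to end. A rough sketch of the write/read shape the test follows; the "vortex" format short name and the output path are assumptions for this sketch, not taken from the diff:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class VortexRoundtripSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("vortex-roundtrip-sketch")
                .master("local[2]")
                .getOrCreate();

        // Same three-column shape as createTestDataFrame: int, string, array<string>.
        Dataset<Row> df = spark.range(0, 100)
                .selectExpr(
                        "cast(id as int) as id",
                        "concat('value_', cast(id as string)) as value",
                        "array('Alpha', 'Bravo', 'Charlie') AS elements");

        // Write two partitions, then read them back. The "vortex" short name
        // is an assumed registration for this sketch.
        String out = "/tmp/vortex_output";
        df.repartition(2).write().format("vortex").save(out);

        Dataset<Row> roundtripped = spark.read().format("vortex").load(out);
        System.out.println("rows=" + roundtripped.count()
                + " cols=" + roundtripped.columns().length); // expect rows=100 cols=3

        spark.stop();
    }
}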
