
Commit 64652f4
arrow zerocopy for read and write in object store
1 parent 8ff162a commit 64652f4

File tree: 5 files changed, +77 -54 lines changed

core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java

Lines changed: 3 additions & 2 deletions

@@ -21,11 +21,12 @@
 import io.ray.api.ObjectRef;
 import io.ray.api.Ray;
 import io.ray.api.call.ActorCreator;
+import io.ray.api.placementgroup.PlacementGroup;
+import io.ray.runtime.object.ObjectRefImpl;
+
 import java.util.Map;
 import java.util.List;
 
-import io.ray.api.placementgroup.PlacementGroup;
-import io.ray.runtime.object.ObjectRefImpl;
 import org.apache.spark.executor.RayDPExecutor;
 
 public class RayExecutorUtils {

core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala

Lines changed: 3 additions & 2 deletions

@@ -22,6 +22,7 @@ import java.util.List;
 import scala.collection.JavaConverters._
 
 import io.ray.runtime.generated.Common.Address
+import org.apache.arrow.vector.VectorSchemaRoot
 
 import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.api.java.JavaSparkContext
@@ -37,15 +38,15 @@ class RayDatasetRDD(
     jsc: JavaSparkContext,
     @transient val objectIds: List[Array[Byte]],
     locations: List[Array[Byte]])
-  extends RDD[Array[Byte]](jsc.sc, Nil) {
+  extends RDD[VectorSchemaRoot](jsc.sc, Nil) {
 
   override def getPartitions: Array[Partition] = {
     objectIds.asScala.zipWithIndex.map { case (k, i) =>
       new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
     }.toArray
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[VectorSchemaRoot] = {
     val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
     ObjectStoreReader.getBatchesFromStream(ref)
   }
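
With this change each RayDatasetRDD partition yields Arrow VectorSchemaRoot batches instead of serialized byte arrays. As an illustration only (the dumpBatch helper below is hypothetical and not part of this commit), a consumer could walk the rows of one batch like this:

// Hypothetical debugging helper: iterate the rows of a VectorSchemaRoot
// produced by RayDatasetRDD#compute.
import scala.collection.JavaConverters._
import org.apache.arrow.vector.VectorSchemaRoot

def dumpBatch(root: VectorSchemaRoot): Unit = {
  val vectors = root.getFieldVectors.asScala
  (0 until root.getRowCount).foreach { i =>
    // ValueVector#getObject boxes every cell; fine for inspection, not for the hot path.
    println(vectors.map(_.getObject(i)).mkString(", "))
  }
}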

core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala

Lines changed: 55 additions & 8 deletions

@@ -18,18 +18,26 @@
 package org.apache.spark.sql.raydp
 
 import java.io.ByteArrayInputStream
+import java.nio.ByteBuffer
 import java.nio.channels.{Channels, ReadableByteChannel}
 import java.util.List
 
+import scala.collection.JavaConverters._
+
 import com.intel.raydp.shims.SparkShimLoader
+import org.apache.arrow.vector.VectorSchemaRoot
 
-import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.TaskContext
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.raydp.RayDPUtils
 import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
 object ObjectStoreReader {
   def createRayObjectRefDF(
@@ -40,17 +48,56 @@ object ObjectStoreReader {
     spark.createDataFrame(rdd, schema)
   }
 
+  def fromRootIterator(
+      arrowRootIter: Iterator[VectorSchemaRoot],
+      schema: StructType,
+      timeZoneId: String): Iterator[InternalRow] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+
+    new Iterator[InternalRow] {
+      private var rowIter = if (arrowRootIter.hasNext) nextBatch() else Iterator.empty
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        if (arrowRootIter.hasNext) {
+          rowIter = nextBatch()
+          true
+        } else {
+          false
+        }
+      }
+
+      override def next(): InternalRow = rowIter.next()
+
+      private def nextBatch(): Iterator[InternalRow] = {
+        val root = arrowRootIter.next()
+        val columns = root.getFieldVectors.asScala.map { vector =>
+          new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
+        }.toArray
+
+        val batch = new ColumnarBatch(columns)
+        batch.setNumRows(root.getRowCount)
+        root.close()
+        batch.rowIterator().asScala
+      }
+    }
+  }
+
   def RayDatasetToDataFrame(
       sparkSession: SparkSession,
       rdd: RayDatasetRDD,
-      schema: String): DataFrame = {
-    SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
+      schemaString: String): DataFrame = {
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    val sqlContext = new SQLContext(sparkSession)
+    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+    val resultRDD = JavaRDD.fromRDD(rdd).rdd.mapPartitions { it =>
+      fromRootIterator(it, schema, timeZoneId)
+    }
+    sqlContext.internalCreateDataFrame(resultRDD.setName("arrow"), schema)
   }
 
   def getBatchesFromStream(
-      ref: Array[Byte]): Iterator[Array[Byte]] = {
-    val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]])
-    ArrowConverters.getBatchesFromStream(
-      Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
+      ref: Array[Byte]): Iterator[VectorSchemaRoot] = {
+    val objectRef = RayDPUtils.readBinary(ref, classOf[VectorSchemaRoot])
+    Iterator[VectorSchemaRoot](objectRef.get)
   }
 }
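
The reader now returns the stored VectorSchemaRoot directly and wraps its vectors in a ColumnarBatch, so no Arrow stream bytes are re-materialized on the Spark side. A minimal usage sketch follows, assuming object-id and location byte lists already produced by the writer; the toSparkDataFrame helper and its parameter names are illustrative, not part of this commit:

// Sketch only: wire RayDatasetRDD into the new reader entry point.
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RayDatasetRDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.raydp.ObjectStoreReader

def toSparkDataFrame(
    spark: SparkSession,
    jsc: JavaSparkContext,
    objectIds: java.util.List[Array[Byte]],
    locations: java.util.List[Array[Byte]],
    schemaJson: String): DataFrame = {
  val rdd = new RayDatasetRDD(jsc, objectIds, locations)
  // RayDatasetToDataFrame parses the JSON schema and builds the DataFrame
  // from the VectorSchemaRoot batches without re-serializing them.
  ObjectStoreReader.RayDatasetToDataFrame(spark, rdd, schemaJson)
}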

core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala

Lines changed: 14 additions & 33 deletions

@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.raydp
 
-
 import java.io.ByteArrayOutputStream
 import java.util.{List, UUID}
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
@@ -61,17 +60,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())
 
   def writeToRay(
-      data: Array[Byte],
+      root: VectorSchemaRoot,
       numRecords: Int,
       queue: ObjectRefHolder.Queue,
       ownerName: String): RecordBatch = {
-
-    var objectRef: ObjectRef[Array[Byte]] = null
+    var objectRef: ObjectRef[VectorSchemaRoot] = null
     if (ownerName == "") {
-      objectRef = Ray.put(data)
+      objectRef = Ray.put(root)
     } else {
       var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
-      objectRef = Ray.put(data, dataOwner)
+      objectRef = Ray.put(root, dataOwner)
     }
 
     // add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,21 +109,15 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val results = new ArrayBuffer[RecordBatch]()
 
-    val byteOut = new ByteArrayOutputStream()
     val arrowWriter = ArrowWriter.create(root)
     var numRecords: Int = 0
 
     Utils.tryWithSafeFinally {
       while (batchIter.hasNext) {
         // reset the state
         numRecords = 0
-        byteOut.reset()
         arrowWriter.reset()
 
-        // write out the schema meta data
-        val writer = new ArrowStreamWriter(root, null, byteOut)
-        writer.start()
-
         // get the next record batch
         val nextBatch = batchIter.next()
 
@@ -136,19 +128,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
 
         // set the write record count
         arrowWriter.finish()
-        // write out the record batch to the underlying out
-        writer.writeBatch()
-
-        // get the wrote ByteArray and save to Ray ObjectStore
-        val byteArray = byteOut.toByteArray
-        results += writeToRay(byteArray, numRecords, queue, ownerName)
-        // end writes footer to the output stream and doesn't clean any resources.
-        // It could throw exception if the output stream is closed, so it should be
-        // in the try block.
-        writer.end()
+
+        // write and schema root directly and save to Ray ObjectStore
+        results += writeToRay(root, numRecords, queue, ownerName)
       }
       arrowWriter.reset()
-      byteOut.close()
     } {
       // If we close root and allocator in TaskCompletionListener, there could be a race
      // condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -173,7 +157,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   /**
    * For test.
    */
-  def getRandomRef(): List[Array[Byte]] = {
+  def getRandomRef(): List[VectorSchemaRoot] = {
 
     df.queryExecution.toRdd.mapPartitions { _ =>
       Iterator(ObjectRefHolder.getRandom(uuid))
@@ -233,7 +217,7 @@ object ObjectStoreWriter {
     var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
     val numExecutors = executorIds.length
     val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
-      .get.asInstanceOf[ActorHandle[RayAppMaster]]
+        .get.asInstanceOf[ActorHandle[RayAppMaster]]
     val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
     // Check if there is any restarted executors
     if (!restartedExecutors.isEmpty) {
@@ -251,8 +235,8 @@
     val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
     val handles = executorIds.map {id =>
       Ray.getActor("raydp-executor-" + id)
-      .get
-      .asInstanceOf[ActorHandle[RayDPExecutor]]
+        .get
+        .asInstanceOf[ActorHandle[RayDPExecutor]]
     }
     val handlesMap = (executorIds zip handles).toMap
     val locations = RayExecutorUtils.getBlockLocations(
@@ -261,18 +245,15 @@
       // TODO use getPreferredLocs, but we don't have a host ip to actor table now
       refs(i) = RayExecutorUtils.getRDDPartition(
        handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
-      queue.add(refs(i))
-    }
-    for (i <- 0 until numPartitions) {
+      queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot]))
       results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
     }
     results
   }
-
 }
 
 object ObjectRefHolder {
-  type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
+  type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]]
   private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()
 
   def getQueue(df: UUID): Queue = {
@@ -297,7 +278,7 @@ object ObjectRefHolder {
     queue.size()
   }
 
-  def getRandom(df: UUID): Array[Byte] = {
+  def getRandom(df: UUID): VectorSchemaRoot = {
     val queue = checkQueueExists(df)
     val ref = RayDPUtils.convert(queue.peek())
     ref.get()
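
The write path now hands the populated VectorSchemaRoot to Ray.put directly, dropping the ArrowStreamWriter/ByteArrayOutputStream round trip. A hedged sketch of how the two halves compose, using only calls that appear in this diff; the roundTrip helper itself is illustrative, not part of the commit:

// Sketch: put a batch without an intermediate byte copy, then read it back
// through the new reader path.
import io.ray.api.{ObjectRef, Ray}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.spark.raydp.RayDPUtils
import org.apache.spark.sql.raydp.ObjectStoreReader

def roundTrip(root: VectorSchemaRoot): Iterator[VectorSchemaRoot] = {
  // write side: the root goes into the Ray object store as-is
  val ref: ObjectRef[VectorSchemaRoot] = Ray.put(root)
  // the object-id bytes are what the writer hands back for later lookup
  val idBytes = RayDPUtils.convert(ref).getId.getBytes
  // read side: getBatchesFromStream resolves the id to the stored root
  ObjectStoreReader.getBatchesFromStream(idBytes)
}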

python/raydp/spark/dataset.py

Lines changed: 2 additions & 9 deletions

@@ -237,7 +237,7 @@ def _convert_blocks_to_dataframe(blocks):
     return df
 
 def _convert_by_rdd(spark: sql.SparkSession,
-                    blocks: Dataset,
+                    blocks: List[ObjectRef],
                     locations: List[bytes],
                     schema: StructType) -> DataFrame:
     object_ids = [block.binary() for block in blocks]
@@ -269,14 +269,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
     schema = StructType()
     for field in arrow_schema:
         schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
-    #TODO how to branch on type of block?
-    sample = ray.get(blocks[0])
-    if isinstance(sample, bytes):
-        return _convert_by_rdd(spark, blocks, locations, schema)
-    elif isinstance(sample, pa.Table):
-        return _convert_by_udf(spark, blocks, locations, schema)
-    else:
-        raise RuntimeError("ray.to_spark only supports arrow type blocks")
+    return _convert_by_rdd(spark, blocks, locations, schema)
 
 if HAS_MLDATASET:
     class RecordBatch(_SourceShard):
