 
 package org.apache.spark.sql.raydp
 
-
 import java.io.ByteArrayOutputStream
 import java.util.{List, UUID}
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
@@ -61,17 +60,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())
 
   def writeToRay(
-      data: Array[Byte],
+      root: VectorSchemaRoot,
       numRecords: Int,
       queue: ObjectRefHolder.Queue,
       ownerName: String): RecordBatch = {
-
-    var objectRef: ObjectRef[Array[Byte]] = null
+    var objectRef: ObjectRef[VectorSchemaRoot] = null
     if (ownerName == "") {
-      objectRef = Ray.put(data)
+      objectRef = Ray.put(root)
     } else {
       var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
-      objectRef = Ray.put(data, dataOwner)
+      objectRef = Ray.put(root, dataOwner)
     }
 
     // add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,21 +109,15 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
       val results = new ArrayBuffer[RecordBatch]()
 
-      val byteOut = new ByteArrayOutputStream()
       val arrowWriter = ArrowWriter.create(root)
       var numRecords: Int = 0
 
       Utils.tryWithSafeFinally {
         while (batchIter.hasNext) {
           // reset the state
           numRecords = 0
-          byteOut.reset()
           arrowWriter.reset()
 
-          // write out the schema meta data
-          val writer = new ArrowStreamWriter(root, null, byteOut)
-          writer.start()
-
           // get the next record batch
           val nextBatch = batchIter.next()
 
@@ -136,19 +128,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
 
           // set the write record count
           arrowWriter.finish()
-          // write out the record batch to the underlying out
-          writer.writeBatch()
-
-          // get the wrote ByteArray and save to Ray ObjectStore
-          val byteArray = byteOut.toByteArray
-          results += writeToRay(byteArray, numRecords, queue, ownerName)
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
+
+          // write the schema root directly and save it to the Ray ObjectStore
+          results += writeToRay(root, numRecords, queue, ownerName)
         }
         arrowWriter.reset()
-        byteOut.close()
       } {
         // If we close root and allocator in TaskCompletionListener, there could be a race
         // condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -173,7 +157,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   /**
    * For test.
    */
-  def getRandomRef(): List[Array[Byte]] = {
+  def getRandomRef(): List[VectorSchemaRoot] = {
 
     df.queryExecution.toRdd.mapPartitions { _ =>
       Iterator(ObjectRefHolder.getRandom(uuid))
@@ -233,7 +217,7 @@ object ObjectStoreWriter {
     var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
     val numExecutors = executorIds.length
     val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
-        .get.asInstanceOf[ActorHandle[RayAppMaster]]
+      .get.asInstanceOf[ActorHandle[RayAppMaster]]
     val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
     // Check if there are any restarted executors
     if (!restartedExecutors.isEmpty) {
@@ -251,8 +235,8 @@ object ObjectStoreWriter {
     val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
     val handles = executorIds.map { id =>
       Ray.getActor("raydp-executor-" + id)
-          .get
-          .asInstanceOf[ActorHandle[RayDPExecutor]]
+         .get
+         .asInstanceOf[ActorHandle[RayDPExecutor]]
     }
     val handlesMap = (executorIds zip handles).toMap
     val locations = RayExecutorUtils.getBlockLocations(
@@ -261,18 +245,15 @@ object ObjectStoreWriter {
       // TODO use getPreferredLocs, but we don't have a host ip to actor table now
       refs(i) = RayExecutorUtils.getRDDPartition(
         handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
-      queue.add(refs(i))
-    }
-    for (i <- 0 until numPartitions) {
+      queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot]))
       results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
     }
     results
   }
-
 }
 
 object ObjectRefHolder {
-  type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
+  type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]]
   private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()
 
   def getQueue(df: UUID): Queue = {
@@ -297,7 +278,7 @@ object ObjectRefHolder {
     queue.size()
   }
 
-  def getRandom(df: UUID): Array[Byte] = {
+  def getRandom(df: UUID): VectorSchemaRoot = {
     val queue = checkQueueExists(df)
     val ref = RayDPUtils.convert(queue.peek())
     ref.get()
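Read informally, the change replaces the old write path, which serialized each Arrow record batch into a byte array with an ArrowStreamWriter before calling Ray.put, with a direct put of the VectorSchemaRoot, and adjusts ObjectRefHolder.Queue and getRandom accordingly. The sketch below contrasts the two paths under that reading; it is not part of the commit, the names WritePathSketch, putAsBytes, and putAsRoot are illustrative, and the direct-put variant assumes a serializer for VectorSchemaRoot is available to Ray (as the RayDPUtils.readBinary call on the read side suggests).

import java.io.ByteArrayOutputStream

import io.ray.api.{ObjectRef, Ray}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

object WritePathSketch {
  // Old path: write the schema and the current batch into an Arrow IPC stream,
  // then put the resulting bytes into the Ray object store.
  def putAsBytes(root: VectorSchemaRoot): ObjectRef[Array[Byte]] = {
    val byteOut = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, byteOut)
    writer.start()       // schema metadata
    writer.writeBatch()  // current record batch
    writer.end()         // stream footer
    Ray.put(byteOut.toByteArray)
  }

  // New path: put the VectorSchemaRoot itself and let Ray handle serialization.
  def putAsRoot(root: VectorSchemaRoot): ObjectRef[VectorSchemaRoot] =
    Ray.put(root)
}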