
Commit 41c9a2b

@W-8256900@ read via local filesystem addendum to #521 (#522)
* file system
* wip
* fix
* wip
1 parent dcf6b67 commit 41c9a2b

File tree: 4 files changed, +48 −42 lines

- core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala
- core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala
- core/src/test/scala/com/salesforce/op/OpWorkflowRunnerTest.scala
- features/src/main/scala/com/salesforce/op/stages/SparkStageParam.scala

core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala

Lines changed: 36 additions & 30 deletions

```diff
@@ -33,20 +33,20 @@ package com.salesforce.op
 
 import java.io.File
 
-import com.salesforce.op.OpWorkflowModelReadWriteShared.{FieldNames => FN}
-import com.salesforce.op.OpWorkflowModelReadWriteShared.FieldNames._
 import com.salesforce.op.OpWorkflowModelReadWriteShared.DeprecatedFieldNames._
+import com.salesforce.op.OpWorkflowModelReadWriteShared.FieldNames._
+import com.salesforce.op.OpWorkflowModelReadWriteShared.{FieldNames => FN}
 import com.salesforce.op.features.{FeatureJsonHelper, OPFeature, TransientFeature}
 import com.salesforce.op.filters.{FeatureDistribution, RawFeatureFilterResults}
 import com.salesforce.op.stages.OpPipelineStageReaderWriter._
 import com.salesforce.op.stages._
-import org.zeroturnaround.zip.ZipUtil
 import org.apache.commons.io.IOUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.io.compress.CompressionCodecFactory
 import org.json4s.JsonAST.{JArray, JNothing, JValue}
 import org.json4s.jackson.JsonMethods.parse
+import org.zeroturnaround.zip.ZipUtil
 
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
@@ -71,34 +71,39 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow], val asSpark: Bo
    */
   final def load(path: String, modelStagingDir: String = WorkflowFileReader.modelStagingDir): OpWorkflowModel = {
     implicit val conf = new Configuration()
-    val localPath = new Path(modelStagingDir)
     val localFileSystem = FileSystem.getLocal(conf)
+    val localPath = localFileSystem.makeQualified(new Path(modelStagingDir))
     localFileSystem.delete(localPath, true)
 
     val savePath = new Path(path)
     val remoteFileSystem = savePath.getFileSystem(conf)
-
-    val zipPath = new Path(localPath, WorkflowFileReader.zipModel)
-
-    remoteFileSystem.copyToLocalFile(savePath, zipPath)
-
+    val zipDir = new Path(localPath, WorkflowFileReader.zipModel)
+    remoteFileSystem.copyToLocalFile(savePath, zipDir)
+
+    // New serialization:
+    //   remote: savePath (dir) -> Model.zip (file)
+    //   local:  Model.zip (dir) -> Model.zip (file)
+    // Old serialization:
+    //   remote: savePath (dir)
+    //   local:  Model.zip (dir)
     val modelDir = new Path(localPath, WorkflowFileReader.rawModel)
-    val fileToLoad = Try {
-      val zipFile = new File(zipPath.toString)
-      val subZip = // TODO figure out why it puts the files like this
-        if (zipFile.isDirectory) new File(zipFile, WorkflowFileReader.zipModel)
-        else zipFile
-      ZipUtil.unpack(subZip, new File(modelDir.toString))
-    } match { // For backwards compatibility since old models will not be zipped
-      case Success(_) => modelDir.toString
-      case Failure(_) => zipPath.toString
-    }
-
-    val model = Try(WorkflowFileReader.loadFile(OpWorkflowModelReadWriteShared.jsonPath(fileToLoad)))
-      .flatMap(loadJson(_, path = fileToLoad)) match {
-      case Failure(error) => throw new RuntimeException(s"Failed to load Workflow from path '$path'", error)
+    val modelPath = Try {
+      localFileSystem.open(new Path(zipDir, WorkflowFileReader.zipModel))
+    }.map { inputStream =>
+      try {
+        ZipUtil.unpack(inputStream, new File(modelDir.toUri.getPath))
+        modelDir.toString
+      } finally inputStream.close()
+    }.getOrElse(zipDir.toString)
+
+    val model = Try(
+      WorkflowFileReader.loadFile(OpWorkflowModelReadWriteShared.jsonPath(modelPath))
+    ).flatMap(loadJson(_, path = modelPath)) match {
+      case Failure(error) =>
+        throw new RuntimeException(s"Failed to load Workflow from path '$path'", error)
       case Success(wf) => wf
     }
+
     localFileSystem.delete(localPath, true)
     model
   }
@@ -260,7 +265,6 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow], val asSpark: Bo
 }
 
 private object WorkflowFileReader {
-
   val rawModel = "rawModel"
   val zipModel = "Model.zip"
   def modelStagingDir: String = s"modelStagingDir/model-${System.currentTimeMillis}"
@@ -286,14 +290,16 @@ private object WorkflowFileReader {
   }
 
   private def readAsString(path: Path)(implicit conf: Configuration): String = {
-    val fs = path.getFileSystem(conf)
     val codecFactory = new CompressionCodecFactory(conf)
     val codec = Option(codecFactory.getCodec(path))
-    val in = fs.open(path)
-    val read = codec.map( c => Source.fromInputStream(c.createInputStream(in)).mkString )
-      .getOrElse( IOUtils.toString(in, "UTF-8") )
-    in.close()
-    read
+    val in = FileSystem.getLocal(conf).open(path)
+    try {
+      val read = codec.map(c => Source.fromInputStream(c.createInputStream(in)).mkString)
+        .getOrElse(IOUtils.toString(in, "UTF-8"))
+      read
+    } finally {
+      in.close()
+    }
   }
 }
```
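The reader change, in short: the saved model is copied from the (possibly remote) save path into a local staging directory; the reader then tries to stream-unpack a `Model.zip` entry from inside the copied directory and, on failure, falls back to treating the copied directory as an already-unpacked old-format model. A condensed, self-contained sketch of that pattern (the object and method names here are illustrative, not part of the commit):

```scala
import java.io.File

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.zeroturnaround.zip.ZipUtil

import scala.util.Try

object ModelLoadSketch {

  // Copy a saved model from `remotePath` into a local staging dir and return
  // the local directory holding the raw model files. New-format saves contain
  // a Model.zip file inside the copied directory; old-format saves are plain
  // directories, so opening the zip entry fails and we fall back to the
  // copied directory itself.
  def resolveModelDir(remotePath: String, stagingDir: String): String = {
    val conf = new Configuration()
    val localFs = FileSystem.getLocal(conf)
    val localPath = localFs.makeQualified(new Path(stagingDir))
    val savePath = new Path(remotePath)
    val zipDir = new Path(localPath, "Model.zip")
    savePath.getFileSystem(conf).copyToLocalFile(savePath, zipDir)

    val rawDir = new Path(localPath, "rawModel")
    Try(localFs.open(new Path(zipDir, "Model.zip"))).map { in =>
      try {
        // toUri.getPath drops the "file:" scheme so java.io.File gets a plain path
        ZipUtil.unpack(in, new File(rawDir.toUri.getPath))
        rawDir.toString
      } finally in.close()
    }.getOrElse(zipDir.toString)
  }
}
```

Streaming the archive through `localFileSystem.open` and converting with `toUri.getPath` sidesteps the `file:` scheme prefix that a qualified Hadoop `Path.toString` carries, which `java.io.File` would otherwise treat as part of the file name.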

core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala

Lines changed: 3 additions & 5 deletions

```diff
@@ -38,7 +38,7 @@ import com.salesforce.op.stages.{OPStage, OpPipelineStageWriter}
 import com.salesforce.op.utils.spark.{JobGroupUtil, OpStep}
 import enumeratum._
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, RawLocalFileSystem}
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.ml.util.MLWriter
 import org.json4s.JsonAST.{JArray, JObject, JString}
@@ -193,8 +193,6 @@ private[op] object OpWorkflowModelReadWriteShared {
  * Writes the OpWorkflowModel into a specified path
  */
 object OpWorkflowModelWriter {
-  val localFileSystem = new RawLocalFileSystem()
-
   /**
    * Save [[OpWorkflowModel]] to path
    *
@@ -211,7 +209,7 @@
   ): Unit = {
     val localPath = new Path(modelStagingDir)
     val conf = new Configuration()
-
+    val localFileSystem = FileSystem.getLocal(conf)
     if (overwrite) localFileSystem.delete(localPath, true)
     val raw = new Path(modelStagingDir, WorkflowFileReader.rawModel)
 
@@ -221,7 +219,7 @@
     val compressed = new Path(modelStagingDir, WorkflowFileReader.zipModel)
     ZipUtil.pack(new File(raw.toString), new File(compressed.toString))
 
-    val finalPath = new Path(path)
+    val finalPath = new Path(path, WorkflowFileReader.zipModel)
     val destinationFileSystem = finalPath.getFileSystem(conf)
     destinationFileSystem.moveFromLocalFile(compressed, finalPath)
   }
```
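The writer is the mirror image: the staged `rawModel` directory is zipped into `Model.zip`, and the archive now moves to `Model.zip` inside the destination path rather than to the path itself, matching the "new serialization" layout documented in the reader. A minimal sketch, assuming zt-zip and the Hadoop client on the classpath (the helper name `publishModel` and the literal folder names are illustrative):

```scala
import java.io.File

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.zeroturnaround.zip.ZipUtil

object ModelSaveSketch {

  // Zip a locally staged model directory and move the archive to the
  // destination filesystem as <path>/Model.zip -- the layout the reader expects.
  // Assumes `stagingDir/rawModel` already holds the serialized model files.
  def publishModel(path: String, stagingDir: String): Unit = {
    val conf = new Configuration()
    val raw = new Path(stagingDir, "rawModel")
    val compressed = new Path(stagingDir, "Model.zip")
    ZipUtil.pack(new File(raw.toString), new File(compressed.toString))

    // moveFromLocalFile copies to the destination filesystem and removes
    // the local source, leaving the staging dir holding only rawModel
    val finalPath = new Path(path, "Model.zip")
    finalPath.getFileSystem(conf).moveFromLocalFile(compressed, finalPath)
  }
}
```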

core/src/test/scala/com/salesforce/op/OpWorkflowRunnerTest.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -226,7 +226,7 @@ class OpWorkflowRunnerTest extends AsyncFlatSpec with PassengerSparkFixtureTest
     outFile.exists shouldBe true
     val dirFile = if (outFile.getAbsolutePath.endsWith("/model")) {
       val unpacked = new File(outFile.getAbsolutePath + "Unpacked")
-      ZipUtil.unpack(outFile, unpacked)
+      ZipUtil.unpack(new File(outFile, "Model.zip"), unpacked)
       unpacked
     } else outFile
     dirFile.isDirectory shouldBe true
```

features/src/main/scala/com/salesforce/op/stages/SparkStageParam.scala

Lines changed: 8 additions & 6 deletions

```diff
@@ -34,11 +34,12 @@ import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
 import ml.combust.bundle.BundleFile
 import ml.combust.bundle.dsl.Bundle
 import ml.combust.bundle.serializer.SerializationFormat
-import ml.combust.mleap.runtime.MleapSupport._
+import ml.combust.mleap.runtime.MleapSupport.MleapBundleFileOps
 import ml.combust.mleap.runtime.frame.{Transformer => MLeapTransformer}
 import ml.combust.mleap.spark.SparkSupport._
 import ml.combust.mleap.xgboost.runtime.bundle.ops.{XGBoostClassificationOp, XGBoostRegressionOp}
-import org.apache.hadoop.fs.{Path, RawLocalFileSystem}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.ml.bundle.SparkBundleContext
 import org.apache.spark.ml.param.{Param, ParamPair, Params}
 import org.apache.spark.ml.util.{Identifiable, MLReader, MLWritable}
@@ -62,8 +63,6 @@ class SparkStageParam[S <: PipelineStage with Params]
 
   import SparkStageParam._
 
-  @transient val rawLocalFileSystem = new RawLocalFileSystem()
-
   /**
    * Spark stage saving path
    */
@@ -92,7 +91,8 @@
     def json(className: String, uid: String) = compact(render(("className" -> className) ~ ("uid" -> uid)))
     (sparkStage, savePath, sbc) match {
       case (Some(stage), Some(path), Some(c)) =>
-        val stagePath = rawLocalFileSystem.makeQualified(new Path(path, stage.uid))
+        val stagePath = SparkStageParam.localFileSystem.makeQualified(new Path(path, stage.uid))
+
         for {bundle <- managed(BundleFile(stagePath.toUri))} {
           stage.asInstanceOf[Transformer].writeBundle.format(SerializationFormat.Json).save(bundle)(c)
             .getOrElse(throw new RuntimeException(s"Failed to write $stage to $path with context $c"))
@@ -166,7 +166,7 @@
         None
       case (Some(path), Some(stageUid), asSpark, className) =>
         savePath = Option(path)
-        val stagePath = rawLocalFileSystem.makeQualified(new Path(path, stageUid))
+        val stagePath = SparkStageParam.localFileSystem.makeQualified(new Path(path, stageUid))
         val loaded = for {bundle <- managed(BundleFile(stagePath.toUri))} yield {
           // TODO remove random forest regression when mleap spark deserialization is fixed
           // https://github.com/combust/mleap/issues/721
@@ -191,6 +191,8 @@ object SparkStageParam {
   val NoUID = ""
   val RandomForestRegressor = "org.apache.spark.ml.regression.RandomForestRegressionModel"
 
+  val localFileSystem = FileSystem.getLocal(new Configuration())
+
   def updateParamsMetadataWithPath(jValue: JValue, path: String, asSpark: Boolean): JValue = jValue match {
     case JObject(pairs) => JObject(
       pairs.map {
```
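A common thread across all four files: each ad-hoc `new RawLocalFileSystem()` is replaced by `FileSystem.getLocal(conf)`, which returns a local filesystem instance that has been initialized with a `Configuration` (a bare `RawLocalFileSystem` constructed with `new` never is). A small sketch of the shared-instance pattern `SparkStageParam` now uses (the path literals are illustrative):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object LocalFileSystemSketch {

  // One initialized local FileSystem shared by save and load, mirroring
  // SparkStageParam.localFileSystem in this commit.
  val localFileSystem: FileSystem = FileSystem.getLocal(new Configuration())

  def main(args: Array[String]): Unit = {
    // makeQualified turns a bare path into a fully qualified file: URI,
    // the form that BundleFile(stagePath.toUri) needs to pick a backend.
    val stagePath = localFileSystem.makeQualified(new Path("/tmp/stages", "stage-uid-000"))
    println(stagePath.toUri) // e.g. file:/tmp/stages/stage-uid-000
  }
}
```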
