3131package com .salesforce .op
3232
3333import java .io .File
34+ import java .nio .charset .StandardCharsets
3435
3536import com .salesforce .op .features .FeatureJsonHelper
3637import com .salesforce .op .filters .RawFeatureFilterResults
3738import com .salesforce .op .stages .{OPStage , OpPipelineStageWriter }
38- import com .salesforce .op .utils .spark .{JobGroupUtil , OpStep }
3939import enumeratum ._
4040import org .apache .hadoop .conf .Configuration
4141import org .apache .hadoop .fs .{FileSystem , Path }
42- import org .apache .hadoop .io .compress .GzipCodec
4342import org .apache .spark .ml .util .MLWriter
4443import org .json4s .JsonAST .{JArray , JObject , JString }
4544import org .json4s .JsonDSL ._
@@ -59,11 +58,39 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
5958
// Implicit `Formats` instance (the library's `DefaultFormats`) made available to code in this
// writer that requires one — presumably the json4s operations; confirm against the full file.
6059 implicit val jsonFormats : Formats = DefaultFormats
6160
/** Local directory in which the model is staged while it is being saved. */
protected var modelStagingDir: String = WorkflowFileReader.modelStagingDir

/**
 * Override the local staging directory used when saving the model.
 *
 * @param localDir path of the local folder to stage the model files in
 * @return this writer, so that calls can be chained
 */
def setModelStagingDir(localDir: String): this.type = {
  this.modelStagingDir = localDir
  this
}
70+
/**
 * Save the workflow model as a single zip archive located under `path`.
 *
 * Steps:
 *  1. the local staging directory (`modelStagingDir`) is deleted, so every save starts clean;
 *  2. the model json is written below the `WorkflowFileReader.rawModel` folder of the staging dir;
 *  3. that folder is packed into the `WorkflowFileReader.zipModel` archive;
 *  4. the archive is moved from the local file system to `path` on whichever file system
 *     `path` resolves to.
 *
 * @param path destination directory that will receive the zipped model
 */
override protected def saveImpl(path: String): Unit = {
  val conf = new Configuration()
  val localFileSystem = FileSystem.getLocal(conf)
  val localPath = localFileSystem.makeQualified(new Path(modelStagingDir))
  // Recursive delete: whatever a previous save left in the staging directory is removed
  localFileSystem.delete(localPath, true)
  val raw = new Path(localPath, WorkflowFileReader.rawModel)

  val rawPathStr = raw.toString
  val modelJson = toJsonString(rawPathStr)
  val jsonPath = OpWorkflowModelReadWriteShared.jsonPath(rawPathStr)
  val os = localFileSystem.create(new Path(jsonPath))
  try {
    // Pass the Charset itself rather than its name: no by-name charset lookup is needed
    os.write(modelJson.getBytes(StandardCharsets.UTF_8))
  } finally {
    os.close()
  }

  // ZipUtil works on java.io.File, hence the plain local paths without the `file:` scheme
  val compressed = new Path(localPath, WorkflowFileReader.zipModel)
  ZipUtil.pack(new File(raw.toUri.getPath), new File(compressed.toUri.getPath))

  // The destination may live on a different (e.g. remote) file system than the staging dir
  val finalPath = new Path(path, WorkflowFileReader.zipModel)
  val destinationFileSystem = finalPath.getFileSystem(conf)
  destinationFileSystem.moveFromLocalFile(compressed, finalPath)
}
6895
6996 /**
@@ -207,21 +234,9 @@ object OpWorkflowModelWriter {
207234 overwrite : Boolean = true ,
208235 modelStagingDir : String = WorkflowFileReader .modelStagingDir
209236 ): Unit = {
210- val localPath = new Path (modelStagingDir)
211- val conf = new Configuration ()
212- val localFileSystem = FileSystem .getLocal(conf)
213- if (overwrite) localFileSystem.delete(localPath, true )
214- val raw = new Path (modelStagingDir, WorkflowFileReader .rawModel)
215-
216- val w = new OpWorkflowModelWriter (model)
237+ val w = new OpWorkflowModelWriter (model).setModelStagingDir(modelStagingDir)
217238 val writer = if (overwrite) w.overwrite() else w
218- writer.save(raw.toString)
219- val compressed = new Path (modelStagingDir, WorkflowFileReader .zipModel)
220- ZipUtil .pack(new File (raw.toString), new File (compressed.toString))
221-
222- val finalPath = new Path (path, WorkflowFileReader .zipModel)
223- val destinationFileSystem = finalPath.getFileSystem(conf)
224- destinationFileSystem.moveFromLocalFile(compressed, finalPath)
239+ writer.save(path)
225240 }
226241
227242 /**
0 commit comments