Skip to content

Commit 2091bb7

Browse files
committed
move example.SparkInferTF to tensorflowonspark.Inference; add --verbose option
1 parent cf8a6c3 commit 2091bb7

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/main/scala/com/yahoo/examples/SparkInferTF.scala renamed to src/main/scala/com/yahoo/tensorflowonspark/Inference.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
1-
package com.yahoo.examples
1+
package com.yahoo.tensorflowonspark
22

3-
import com.yahoo.tensorflowonspark.{DFUtil, SimpleTypeParser, TFModel}
3+
import org.apache.spark.sql.SparkSession
44
import org.apache.spark.sql.types._
55
import org.apache.spark.{SparkConf, SparkContext}
6-
import org.apache.spark.sql.SparkSession
76
import org.json4s._
87
import org.json4s.native.JsonMethods
98

10-
object SparkInferTF {
9+
/**
10+
* Spark application that infers from a TensorFlow SavedModel.
11+
*/
12+
object Inference {
1113

1214
case class Config(export_dir: String = "",
1315
input: String = "",
1416
schema_hint: StructType = new StructType(),
1517
input_mapping: Map[String, String] = Map.empty,
1618
output_mapping: Map[String, String] = Map.empty,
17-
output: String = "")
19+
output: String = "",
20+
verbose: Boolean = false)
1821

1922
def main(args: Array[String]) {
20-
val conf = new SparkConf().setAppName("SparkInferTF")
23+
val conf = new SparkConf().setAppName("Inference")
2124
implicit val sc: SparkContext = new SparkContext(conf)
22-
val parser = new scopt.OptionParser[Config]("SparkInferTF") {
25+
val parser = new scopt.OptionParser[Config]("Inference") {
2326
opt[String]("export_dir").text("Path to exported saved_model")
2427
.action((x, conf) => conf.copy(export_dir = x))
2528
opt[String]("input").text("Path to input TFRecords")
@@ -31,6 +34,7 @@ object SparkInferTF {
3134
opt[String]("output_mapping").text("JSON mapping of output tensors to output columns")
3235
.action((x, conf) => conf.copy(output_mapping = JsonMethods.parse(x).values.asInstanceOf[Map[String, String]]))
3336
opt[String]("output").text("Path to write predictions").action((x, conf) => conf.copy(output = x))
37+
opt[Unit]("verbose").text("Print input dataframe sample with schema").action((_, conf) => conf.copy(verbose = true))
3438
}
3539

3640
parser.parse(args, Config()) match {
@@ -46,8 +50,10 @@ object SparkInferTF {
4650

4751
// load TFRecords as a Spark DataFrame (using a user-provided schema hint)
4852
val df = DFUtil.loadTFRecords(config.input, config.schema_hint)
49-
df.show()
50-
df.printSchema()
53+
if (config.verbose) {
54+
df.show()
55+
df.printSchema()
56+
}
5157

5258
// instantiate a TFModel pointing to an existing TensorFlow saved_model export
5359
// set up mappings between input DataFrame columns to input Tensors

0 commit comments

Comments
 (0)