1- package com .yahoo .examples
1+ package com .yahoo .tensorflowonspark
22
3- import com . yahoo . tensorflowonspark .{ DFUtil , SimpleTypeParser , TFModel }
3+ import org . apache . spark . sql . SparkSession
44import org .apache .spark .sql .types ._
55import org .apache .spark .{SparkConf , SparkContext }
6- import org .apache .spark .sql .SparkSession
76import org .json4s ._
87import org .json4s .native .JsonMethods
98
10- object SparkInferTF {
9+ /**
10+ * Spark application that infers from a TensorFlow SavedModel.
11+ */
12+ object Inference {
1113
1214 case class Config (export_dir : String = " " ,
1315 input : String = " " ,
1416 schema_hint : StructType = new StructType (),
1517 input_mapping : Map [String , String ] = Map .empty,
1618 output_mapping : Map [String , String ] = Map .empty,
17- output : String = " " )
19+ output : String = " " ,
20+ verbose : Boolean = false )
1821
1922 def main (args : Array [String ]) {
20- val conf = new SparkConf ().setAppName(" SparkInferTF " )
23+ val conf = new SparkConf ().setAppName(" Inference " )
2124 implicit val sc : SparkContext = new SparkContext (conf)
22- val parser = new scopt.OptionParser [Config ](" SparkInferTF " ) {
25+ val parser = new scopt.OptionParser [Config ](" Inference " ) {
2326 opt[String ](" export_dir" ).text(" Path to exported saved_model" )
2427 .action((x, conf) => conf.copy(export_dir = x))
2528 opt[String ](" input" ).text(" Path to input TFRecords" )
@@ -31,6 +34,7 @@ object SparkInferTF {
3134 opt[String ](" output_mapping" ).text(" JSON mapping of output tensors to output columns" )
3235 .action((x, conf) => conf.copy(output_mapping = JsonMethods .parse(x).values.asInstanceOf [Map [String , String ]]))
3336 opt[String ](" output" ).text(" Path to write predictions" ).action((x, conf) => conf.copy(output = x))
37+ opt[Unit ](" verbose" ).text(" Print input dataframe sample with schema" ).action((_, conf) => conf.copy(verbose = true ))
3438 }
3539
3640 parser.parse(args, Config ()) match {
@@ -46,8 +50,10 @@ object SparkInferTF {
4650
4751 // load TFRecords as a Spark DataFrame (using a user-provided schema hint)
4852 val df = DFUtil .loadTFRecords(config.input, config.schema_hint)
49- df.show()
50- df.printSchema()
53+ if (config.verbose) {
54+ df.show()
55+ df.printSchema()
56+ }
5157
5258 // instantiate a TFModel pointing to an existing TensorFlow saved_model export
5359 // set up mappings between input DataFrame columns to input Tensors
0 commit comments