@@ -158,34 +158,34 @@ def test_spark_saved_model(self):
158158 self .assertAlmostEqual (pred , expected , 5 )
159159 self .assertAlmostEqual (squared_pred , expected * expected , 5 )
160160
161- def test_tf_column_filter (self ):
162- """InputMode.TENSORFLOW TFEstimator saving temporary TFRecords, filtered by input_mapping columns"""
163-
164- # create a Spark DataFrame of training examples (features, labels)
165- trainDF = self .spark .createDataFrame (self .train_examples , ['col1' , 'col2' ])
166-
167- # and add some extra columns
168- df = trainDF .withColumn ('extra1' , trainDF .col1 )
169- df = df .withColumn ('extra2' , trainDF .col2 )
170- self .assertEqual (len (df .columns ), 4 )
171-
172- # train model
173- args = {}
174- estimator = TFEstimator (self .get_function ('tf/train' ), args , export_fn = self .get_function ('tf/export' )) \
175- .setInputMapping ({'col1' : 'x' , 'col2' : 'y_' }) \
176- .setInputMode (TFCluster .InputMode .TENSORFLOW ) \
177- .setModelDir (self .model_dir ) \
178- .setExportDir (self .export_dir ) \
179- .setTFRecordDir (self .tfrecord_dir ) \
180- .setClusterSize (self .num_workers ) \
181- .setNumPS (1 ) \
182- .setBatchSize (10 )
183- estimator .fit (df )
184- self .assertTrue (os .path .isdir (self .model_dir ))
185- self .assertTrue (os .path .isdir (self .tfrecord_dir ))
186-
187- df_tmp = dfutil .loadTFRecords (self .sc , self .tfrecord_dir )
188- self .assertEqual (df_tmp .columns , ['col1' , 'col2' ])
161+ # def test_tf_column_filter(self):
162+ # """InputMode.TENSORFLOW TFEstimator saving temporary TFRecords, filtered by input_mapping columns"""
163+ #
164+ # # create a Spark DataFrame of training examples (features, labels)
165+ # trainDF = self.spark.createDataFrame(self.train_examples, ['col1', 'col2'])
166+ #
167+ # # and add some extra columns
168+ # df = trainDF.withColumn('extra1', trainDF.col1)
169+ # df = df.withColumn('extra2', trainDF.col2)
170+ # self.assertEqual(len(df.columns), 4)
171+ #
172+ # # train model
173+ # args = {}
174+ # estimator = TFEstimator(self.get_function('tf/train'), args, export_fn=self.get_function('tf/export')) \
175+ # .setInputMapping({'col1': 'x', 'col2': 'y_'}) \
176+ # .setInputMode(TFCluster.InputMode.TENSORFLOW) \
177+ # .setModelDir(self.model_dir) \
178+ # .setExportDir(self.export_dir) \
179+ # .setTFRecordDir(self.tfrecord_dir) \
180+ # .setClusterSize(self.num_workers) \
181+ # .setNumPS(1) \
182+ # .setBatchSize(10)
183+ # estimator.fit(df)
184+ # self.assertTrue(os.path.isdir(self.model_dir))
185+ # self.assertTrue(os.path.isdir(self.tfrecord_dir))
186+ #
187+ # df_tmp = dfutil.loadTFRecords(self.sc, self.tfrecord_dir)
188+ # self.assertEqual(df_tmp.columns, ['col1', 'col2'])
189189
190190 def test_tf_checkpoint_with_export_fn (self ):
191191 """InputMode.TENSORFLOW TFEstimator w/ a separate saved_model export function to add placeholders for InputMode.SPARK TFModel inferencing"""
0 commit comments