@@ -8,7 +8,10 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext}
88import org .apache .spark .sql .types ._
99import org .joda .time .format .DateTimeFormat
1010import org .apache .spark .ml .linalg .Vectors
11+ import org .apache .spark .ml .linalg .Vector
1112import org .apache .spark .ml .linalg .VectorUDT
13+ import org .apache .spark .sql .Dataset
14+ import org .apache .spark .ml .util ._
1215
1316/**
1417 * Loads various test datasets
@@ -28,21 +31,26 @@ object TestHelper {
2831 /**
2932 * @return the feature selector model fit to the data, given the input columns to assemble into features and the label column to use as target.
3033 */
31- def createSelectorModel (dataframe : DataFrame , inputCols : Array [String ],
34+
35+ def createSelectorModel (sqlContext : SQLContext , dataframe : Dataset [_], inputCols : Array [String ],
3236 labelColumn : String ,
3337 nPartitions : Int = 100 ,
3438 numTopFeatures : Int = 20 ,
3539 allVectorsDense : Boolean = true ): InfoThSelectorModel = {
3640 val featureAssembler = new VectorAssembler ()
3741 .setInputCols(inputCols)
3842 .setOutputCol(" features" )
39- val processedDf = featureAssembler.transform(dataframe)
40-
43+ val processedDf = featureAssembler.transform(dataframe).select(labelColumn + INDEX_SUFFIX , " features" )
4144
42- processedDf.map {
45+ /** InfoSelector requires all vectors to be of the same type (either all sparse or all dense) **/
46+ val rddData = processedDf.rdd.map {
4347 case Row (label : Double , features : Vector ) =>
44- OldLabeledPoint (label, OldVectors .fromML(features))
48+ val standardv = if (allVectorsDense) features.toDense else features.toSparse
49+ Row .fromSeq(Seq (label, standardv))
4550 }
51+
52+ val inputData = sqlContext.createDataFrame(rddData, processedDf.schema)
53+
4654 val selector = new InfoThSelector ()
4755 .setSelectCriterion(" mrmr" )
4856 .setNPartitions(nPartitions)
@@ -51,20 +59,20 @@ object TestHelper {
5159 .setLabelCol(labelColumn + INDEX_SUFFIX )
5260 .setOutputCol(" selectedFeatures" )
5361
54- selector.fit(processedDf )
62+ selector.fit(inputData )
5563 }
5664
5765
/**
 * Fits a feature selector on a dataframe whose label column is cleaned first.
 *
 * The label column is passed through `cleanLabelCol` before fitting (null labels are
 * expected to be replaced with MISSING values there), and the cleaned dataframe is then
 * handed to `createSelectorModel` together with the remaining arguments.
 *
 * @param sqlContext     context forwarded to `createSelectorModel`
 * @param dataframe      input data containing the feature columns and the label column
 * @param inputCols      names of the columns to use as candidate features
 * @param labelColumn    name of the label column to clean and use as target
 * @param nPartitions    number of partitions forwarded to the selector (default 100)
 * @param numTopFeatures number of features forwarded to the selector (default 20)
 * @return the selector model fit to the cleaned data
 */
def getSelectorModel(sqlContext: SQLContext, dataframe: DataFrame, inputCols: Array[String],
                     labelColumn: String,
                     nPartitions: Int = 100,
                     numTopFeatures: Int = 20): InfoThSelectorModel = {
  val cleanedDf = cleanLabelCol(dataframe, labelColumn)
  // The two Int settings are passed by name so they cannot be swapped silently.
  createSelectorModel(
    sqlContext,
    cleanedDf,
    inputCols,
    labelColumn,
    nPartitions = nPartitions,
    numTopFeatures = numTopFeatures)
}
6977
7078
0 commit comments