Commit 89c3b00

All tests from Peng's page have been implemented and pass without errors.
1 parent ed272db commit 89c3b00

2 files changed: +68 -46 lines changed

src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala

Lines changed: 62 additions & 4 deletions
@@ -8,7 +8,7 @@ import TestHelper._
 
 
 /**
-  * Test infomartion theoretic feature selection
+  * Test information theoretic feature selection on datasets from Peng's webpage
   *
   * @author Sergio Ramirez
   */
@@ -21,20 +21,78 @@ class ITSelectorSuite extends FunSuite with BeforeAndAfterAll {
     sqlContext = new SQLContext(SPARK_CTX)
   }
 
-  /** Do entropy based binning of cars data from UC Irvine repository. */
+  /** Do mRMR feature selection on COLON data. */
   test("Run ITFS on colon data (nPart = 10, nfeat = 10)") {
 
-    val df = readColonData(sqlContext)
+    val df = readCSVData(sqlContext, "test_colon_s3.csv")
     val cols = df.columns
     val pad = 2
     val allVectorsDense = true
-    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, 10, 10, allVectorsDense, pad)
+    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+      10, 10, allVectorsDense, pad)
 
     assertResult("512, 764, 1324, 1380, 1411, 1422, 1581, 1670, 1671, 1971") {
       model.selectedFeatures.mkString(", ")
     }
   }
 
+  /** Do mRMR feature selection on LEUKEMIA data. */
+  test("Run ITFS on leukemia data (nPart = 10, nfeat = 10)") {
+
+    val df = readCSVData(sqlContext, "test_leukemia_s3.csv")
+    val cols = df.columns
+    val pad = 2
+    val allVectorsDense = true
+    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+      10, 10, allVectorsDense, pad)
+
+    assertResult("1084, 1719, 1774, 1822, 2061, 2294, 3192, 4387, 4787, 6795") {
+      model.selectedFeatures.mkString(", ")
+    }
+  }
 
+  /** Do mRMR feature selection on LUNG data. */
+  test("Run ITFS on lung data (nPart = 10, nfeat = 10)") {
+
+    val df = readCSVData(sqlContext, "test_lung_s3.csv")
+    val cols = df.columns
+    val pad = 2
+    val allVectorsDense = true
+    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+      10, 10, allVectorsDense, pad)
+
+    assertResult("18, 22, 29, 125, 132, 150, 166, 242, 243, 269") {
+      model.selectedFeatures.mkString(", ")
+    }
+  }
+
+  /** Do mRMR feature selection on LYMPHOMA data. */
+  test("Run ITFS on lymphoma data (nPart = 10, nfeat = 10)") {
 
+    val df = readCSVData(sqlContext, "test_lymphoma_s3.csv")
+    val cols = df.columns
+    val pad = 2
+    val allVectorsDense = true
+    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+      10, 10, allVectorsDense, pad)
+
+    assertResult("236, 393, 759, 2747, 2818, 2841, 2862, 3014, 3702, 3792") {
+      model.selectedFeatures.mkString(", ")
+    }
+  }
+
+  /** Do mRMR feature selection on NCI data. */
+  test("Run ITFS on nci data (nPart = 10, nfeat = 10)") {
+
+    val df = readCSVData(sqlContext, "test_nci9_s3.csv")
+    val cols = df.columns
+    val pad = 2
+    val allVectorsDense = true
+    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+      10, 10, allVectorsDense, pad)
+
+    assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") {
+      model.selectedFeatures.mkString(", ")
+    }
+  }
 }
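
All five tests follow the same pattern: load one of Peng's benchmark CSVs, treat the first column as the class label and the remaining columns as features, run minimum-Redundancy-Maximum-Relevance (mRMR) selection over 10 partitions for the top 10 features, and assert the exact indices returned. The sketch below shows roughly what getSelectorModel wires together, assuming the InfoThSelector estimator API as documented in the spark-infotheoretic-feature-selection README; the criterion string, file path, and label handling are assumptions, not taken from this diff.

    import org.apache.spark.ml.feature.{InfoThSelector, VectorAssembler}
    import org.apache.spark.sql.SQLContext

    // Minimal sketch of one test run (hypothetical, not the repo's exact helper).
    def runITFSOnColon(sqlContext: SQLContext): Unit = {
      // Load a CSV whose first column is the label ("test_colon_s3.csv" above).
      val df = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("header", "true")
        .option("inferSchema", "true")
        .load("src/test/resources/test_colon_s3.csv") // assumed location

      val cols = df.columns
      // Assemble every column except the label into a single feature vector.
      val assembled = new VectorAssembler()
        .setInputCols(cols.drop(1))
        .setOutputCol("features")
        .transform(df)

      // Fit the information-theoretic selector; setter names follow the
      // library README and may differ between versions.
      val model = new InfoThSelector()
        .setSelectCriterion("mrmr") // minimum Redundancy Maximum Relevance
        .setNPartitions(10)
        .setNumTopFeatures(10)
        .setFeaturesCol("features")
        .setLabelCol(cols.head)
        .setOutputCol("selectedFeatures")
        .fit(assembled)

      // Indices of the selected original features, e.g. "512, 764, ..." above.
      println(model.selectedFeatures.mkString(", "))
    }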

src/test/scala/org/apache/spark/ml/feature/TestHelper.scala

Lines changed: 6 additions & 42 deletions
@@ -29,15 +29,15 @@ object TestHelper {
   final val INDEX_SUFFIX: String = "_IDX"
 
   /**
-    * @return the discretizer fit to the data given the specified features to bin and label use as target.
+    * @return the feature selector fit to the data given the specified features to bin and the label to use as target.
     */
 
   def createSelectorModel(sqlContext: SQLContext, dataframe: Dataset[_], inputCols: Array[String],
                           labelColumn: String,
                           nPartitions: Int = 100,
                           numTopFeatures: Int = 20,
                           allVectorsDense: Boolean = true,
-                          padded: Int = 0): InfoThSelectorModel = {
+                          padded: Int = 0 /* if minimum value is negative */): InfoThSelectorModel = {
     val featureAssembler = new VectorAssembler()
       .setInputCols(inputCols)
       .setOutputCol("features")
@@ -73,7 +73,7 @@ object TestHelper {
 
   /**
     * The label column will have null values replaced with MISSING values in this case.
-    * @return the discretizer fit to the data given the specified features to bin and label use as target.
+    * @return the feature selector fit to the data given the specified features to bin and the label to use as target.
     */
   def getSelectorModel(sqlContext: SQLContext, dataframe: DataFrame, inputCols: Array[String],
                        labelColumn: String,
@@ -121,53 +121,17 @@ object TestHelper {
     sc
   }
 
-  /** @return standard iris dataset from UCI repo.
-    */
-  /*def readColonData(sqlContext: SQLContext): DataFrame = {
-    val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data")
-    val nullable = true
-
-    val schema = (0 until 9712).map(i => StructField("var" + i, DoubleType, nullable)).toList :+
-      StructField("colontype", StringType, nullable)
-    // ints and dates must be read as doubles
-    val rows = data.map(line => line.split(",").map(elem => elem.trim))
-      .map(x => {Row.fromSeq(Seq(asDouble(x(0)), asDouble(x(1)), asDouble(x(2)), asDouble(x(3)), asString(x(4))))})
-
-    sqlContext.createDataFrame(rows, schema)
-  }
-
-  /** @return standard iris dataset from UCI repo.
+  /** @return standard csv data from the repo.
     */
-  def readColonData2(sqlContext: SQLContext): DataFrame = {
-    val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data")
-    val nullable = true
-    val schema = StructType(List(
-      StructField("features", new VectorUDT, nullable),
-      StructField("class", DoubleType, nullable)
-    ))
-    val rows = data.map{line =>
-      val split = line.split(",").map(elem => elem.trim)
-      val features = Vectors.dense(split.drop(1).map(_.toDouble))
-      val label = split.head.toDouble
-      (features, label)
-    }
-    val asd = sqlContext.createDataFrame(rows, schema)
-
-  }*/
-
-
-  def readColonData(sqlContext: SQLContext): DataFrame = {
+  def readCSVData(sqlContext: SQLContext, file: String): DataFrame = {
     val df = sqlContext.read
      .format("com.databricks.spark.csv")
      .option("header", "true") // Use first line of all files as header
      .option("inferSchema", "true") // Automatically infer data types
-      .load(FILE_PREFIX + "test_colon_s3.csv")
+      .load(FILE_PREFIX + file)
     df
   }
-
 
-
-
   /** @return dataset with 3 double columns. The first is the label column and contain null.
     */
   def readNullLabelTestData(sqlContext: SQLContext): DataFrame = {
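
A side effect of this refactor: since readCSVData now takes the file name as a parameter, covering an additional dataset no longer needs a bespoke reader. A hypothetical new test would look like the sketch below; the CSV name is a placeholder, not a fixture that exists in the repo, and the expected indices are left unpinned.

    /** Do mRMR feature selection on a hypothetical NEW dataset. */
    test("Run ITFS on new data (nPart = 10, nfeat = 10)") {

      val df = readCSVData(sqlContext, "test_new_s3.csv") // placeholder file
      val cols = df.columns
      val pad = 2
      val allVectorsDense = true
      val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
        10, 10, allVectorsDense, pad)

      // Inspect the selection before pinning exact indices in an assertResult.
      println(model.selectedFeatures.mkString(", "))
    }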
