Skip to content

Commit eba38a0

Browse files
authored
put feature feature corr behind flag (#479)
1 parent 93e1fde commit eba38a0

File tree

4 files changed

+64
-22
lines changed

4 files changed

+64
-22
lines changed

core/src/main/scala/com/salesforce/op/dsl/RichNumericFeature.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ package com.salesforce.op.dsl
3333
import com.salesforce.op.features.FeatureLike
3434
import com.salesforce.op.features.types._
3535
import com.salesforce.op.stages.impl.feature._
36-
import com.salesforce.op.stages.impl.preparators.{CorrelationExclusion, CorrelationType, SanityChecker}
36+
import com.salesforce.op.stages.impl.preparators.{CorrelationExclusion, CorrelationLevel, CorrelationType, SanityChecker}
3737
import com.salesforce.op.stages.impl.regression.IsotonicRegressionCalibrator
3838

3939
import scala.language.postfixOps
@@ -483,7 +483,7 @@ trait RichNumericFeature {
483483
protectTextSharedHash: Boolean = SanityChecker.ProtectTextSharedHash,
484484
maxRuleConfidence: Double = SanityChecker.MaxRuleConfidence,
485485
minRequiredRuleSupport: Double = SanityChecker.MinRequiredRuleSupport,
486-
featureLabelCorrOnly: Boolean = SanityChecker.FeatureLabelCorrOnly,
486+
featureFeatureCorrLevel: CorrelationLevel = SanityChecker.FeatureFeatureCorrLevel,
487487
correlationExclusion: CorrelationExclusion = SanityChecker.CorrelationExclusionDefault,
488488
categoricalLabel: Option[Boolean] = None
489489
): FeatureLike[OPVector] = {
@@ -504,7 +504,7 @@ trait RichNumericFeature {
504504
.setProtectTextSharedHash(protectTextSharedHash)
505505
.setMaxRuleConfidence(maxRuleConfidence)
506506
.setMinRequiredRuleSupport(minRequiredRuleSupport)
507-
.setFeatureLabelCorrOnly(featureLabelCorrOnly)
507+
.setFeatureFeatureCorrLevel(featureFeatureCorrLevel)
508508
.setCorrelationExclusion(correlationExclusion)
509509
.setInput(f, featureVector)
510510

core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityChecker.scala

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,21 @@ trait SanityCheckerParams extends DerivedFeatureFilterParams {
177177
def setMinRequiredRuleSupport(value: Double): this.type = set(minRequiredRuleSupport, value)
178178
def getMinRequiredRuleSupport: Double = $(minRequiredRuleSupport)
179179

180-
final val featureLabelCorrOnly = new BooleanParam(
181-
parent = this, name = "featureLabelCorrOnly",
182-
doc = "If true, then only calculate the correlations between the features and the label. Otherwise, calculate " +
183-
"the entire correlation matrix, which includes all feature-feature correlations."
180+
final val featureFeatureCorrLevel = new Param[String](
181+
parent = this, name = "featureFeatureCorrOnly",
182+
doc = "This setting determines feature-feature correlation computations. Levels are: Off, Computed, Stored"
184183
)
185-
def setFeatureLabelCorrOnly(value: Boolean): this.type = set(featureLabelCorrOnly, value)
186-
def getFeatureLabelCorrOnly: Boolean = $(featureLabelCorrOnly)
184+
def setFeatureFeatureCorrLevel(value: CorrelationLevel): this.type = set(featureFeatureCorrLevel, value.entryName)
185+
def getFeatureFeatureCorrLevel: CorrelationLevel = CorrelationLevel.withName($(featureFeatureCorrLevel))
186+
187+
@deprecated("this setting is overridden by featureFeatureCorrLevel", "0.7.0")
188+
def setFeatureLabelCorrOnly(value: Boolean): this.type = {
189+
if (value) set(featureFeatureCorrLevel, CorrelationLevel.Off.entryName)
190+
else set(featureFeatureCorrLevel, CorrelationLevel.Computed.entryName)
191+
}
192+
193+
@deprecated("this setting is overridden by featureFeatureCorrLevel", "0.7.0")
194+
def getFeatureLabelCorrOnly: Boolean = $(featureFeatureCorrLevel) == CorrelationLevel.Off.entryName
187195

188196
final val correlationExclusion: Param[String] = new Param[String](this, "correlationExclusion",
189197
"Setting for what categories of feature vector columns to exclude from the correlation calculation",
@@ -208,7 +216,7 @@ trait SanityCheckerParams extends DerivedFeatureFilterParams {
208216
correlationType -> SanityChecker.CorrelationTypeDefault.entryName,
209217
maxRuleConfidence -> SanityChecker.MaxRuleConfidence,
210218
minRequiredRuleSupport -> SanityChecker.MinRequiredRuleSupport,
211-
featureLabelCorrOnly -> SanityChecker.FeatureLabelCorrOnly,
219+
featureFeatureCorrLevel -> SanityChecker.FeatureFeatureCorrLevel.entryName,
212220
correlationExclusion -> SanityChecker.CorrelationExclusionDefault.entryName
213221
)
214222
}
@@ -453,7 +461,7 @@ class SanityChecker(uid: String = UID[SanityChecker])
453461
else ((0 until featureSize + 1).toArray, vectorRows)
454462
val numCorrIndices = corrIndices.length
455463

456-
val (corrMatrix, corrsWithLabel) = if ($(featureLabelCorrOnly)) {
464+
val (corrMatrix, corrsWithLabel) = if ($(featureFeatureCorrLevel) == CorrelationLevel.Off.entryName) {
457465
None -> OpStatistics.computeCorrelationsWithLabel(vectorRowsForCorr, colStats, count)
458466
}
459467
else {
@@ -513,7 +521,8 @@ class SanityChecker(uid: String = UID[SanityChecker])
513521
colStats = colStats,
514522
names = featureNames :+ in1.name,
515523
correlationType = CorrelationType.withNameInsensitive(corrType),
516-
sample = sampleFraction
524+
sample = sampleFraction,
525+
keepFeatureFeature = getFeatureFeatureCorrLevel
517526
)
518527
setMetadata(outputMeta.toMetadata.withSummaryMetadata(summary.toMetadata()))
519528

@@ -565,7 +574,7 @@ object SanityChecker {
565574
// These settings turn the maxRuleConfidence check off by default
566575
val MaxRuleConfidence = 1.0
567576
val MinRequiredRuleSupport = 1.0
568-
val FeatureLabelCorrOnly = false
577+
val FeatureFeatureCorrLevel = CorrelationLevel.Computed
569578
val CorrelationExclusionDefault = CorrelationExclusion.NoExclusion
570579

571580
def SampleSeed: Long = util.Random.nextLong() // scalastyle:off method.name
@@ -620,3 +629,28 @@ object CorrelationExclusion extends Enum[CorrelationExclusion] {
620629
*/
621630
case object HashedText extends CorrelationExclusion
622631
}
632+
633+
634+
/**
635+
* Settings controlling whether feature-feature correlations are computed and whether they are stored
636+
*/
637+
sealed trait CorrelationLevel extends EnumEntry with Serializable
638+
639+
object CorrelationLevel extends Enum[CorrelationLevel] {
640+
val values: Seq[CorrelationLevel] = findValues
641+
642+
/**
643+
* Feature-feature correlations are off
644+
*/
645+
case object Off extends CorrelationLevel
646+
647+
/**
648+
* Feature-feature correlations are computed for feature exclusion but are not stored in metadata
649+
*/
650+
case object Computed extends CorrelationLevel
651+
652+
/**
653+
* Feature-feature correlations are computed and also stored in metadata
654+
*/
655+
case object Stored extends CorrelationLevel
656+
}

core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerMetadata.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,17 @@ case class SanityCheckerSummary
8989
colStats: MultivariateStatisticalSummary,
9090
names: Seq[String],
9191
correlationType: CorrelationType,
92-
sample: Double
92+
sample: Double,
93+
keepFeatureFeature: CorrelationLevel
9394
) {
9495
this(
9596
correlations = new Correlations(
96-
stats.filter(s => s.corrLabel.isDefined).map(s => (s.name, s.corrLabel.get, s.featureCorrs)),
97+
stats.filter(s => s.corrLabel.isDefined).map { s =>
98+
keepFeatureFeature match {
99+
case CorrelationLevel.Stored => (s.name, s.corrLabel.get, s.featureCorrs)
100+
case _ => (s.name, s.corrLabel.get, Seq.empty)
101+
}
102+
},
97103
correlationType
98104
),
99105
dropped = dropped,
@@ -288,7 +294,7 @@ case class Correlations
288294
def this(corrs: Seq[(String, Double, Seq[Double])], corrType: CorrelationType) = this(
289295
featuresIn = corrs.map(_._1),
290296
valuesWithLabel = corrs.map(_._2),
291-
valuesWithFeatures = corrs.map(_._3),
297+
valuesWithFeatures = if (corrs.flatMap(_._3).isEmpty) Seq.empty else corrs.map(_._3),
292298
corrType = corrType
293299
)
294300

core/src/test/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerTest.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
532532
featuresToDrop, featuresWithNaNCorr, featuresIngore)
533533
}
534534

535-
it should "only calculate correlations between feature and the label if requested" in {
535+
it should "only store correlations between feature and the label if requested" in {
536536
val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
537537
.setMaxCardinality(2).setNumFeatures(8).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
538538
.setHashSpaceStrategy(HashSpaceStrategy.Shared)
@@ -544,7 +544,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
544544
.setRemoveBadFeatures(true)
545545
.setRemoveFeatureGroup(true)
546546
.setProtectTextSharedHash(true)
547-
.setFeatureLabelCorrOnly(true)
547+
.setFeatureFeatureCorrLevel(CorrelationLevel.Stored)
548548
.setMinCorrelation(0.0)
549549
.setMaxCorrelation(0.8)
550550
.setMaxCramersV(0.8)
@@ -564,7 +564,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
564564

565565
val expectedFeatNames = featuresWithCorr ++ featuresWithNaNCorr
566566
validateTransformerOutput(checkedFeatures.name, transformed, expectedFeatNames,
567-
featuresToDrop, featuresWithNaNCorr)
567+
featuresToDrop, featuresWithNaNCorr, hasFeatureFeature = true)
568568
}
569569

570570
it should "not fail when calculating feature-label correlations on a 5k element feature vector" in {
@@ -584,7 +584,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
584584
.setRemoveBadFeatures(false)
585585
.setRemoveFeatureGroup(true)
586586
.setProtectTextSharedHash(true)
587-
.setFeatureLabelCorrOnly(true)
587+
.setFeatureFeatureCorrLevel(CorrelationLevel.Off)
588588
.setMinVariance(-0.1)
589589
.setMinCorrelation(0.0)
590590
.setMaxCorrelation(0.8)
@@ -611,7 +611,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
611611

612612
val expectedFeatNames = featuresWithCorr ++ featuresWithNaNCorr
613613
validateTransformerOutput(checkedFeatures.name, transformed, expectedFeatNames,
614-
featuresToDrop, featuresWithNaNCorr)
614+
featuresToDrop, featuresWithNaNCorr, hasFeatureFeature = false)
615615
}
616616

617617
it should "not fail when maps have the same keys" in {
@@ -666,7 +666,8 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
666666
expectedFeatNames: Seq[String],
667667
expectedFeaturesToDrop: Seq[String],
668668
expectedCorrFeatNamesIsNan: Seq[String],
669-
ignoredNames: Seq[String] = Seq.empty
669+
ignoredNames: Seq[String] = Seq.empty,
670+
hasFeatureFeature: Boolean = false
670671
): Unit = {
671672
transformedData.select(outputColName).collect().foreach { case Row(features: Vector) =>
672673
features.toArray.length equals
@@ -682,6 +683,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
682683
} should contain theSameElementsAs expectedCorrFeatNamesIsNan
683684
summary.correlations.featuresIn should contain theSameElementsAs expectedFeatNames.diff(ignoredNames)
684685
summary.dropped should contain theSameElementsAs expectedFeaturesToDrop
686+
summary.correlations.valuesWithFeatures.nonEmpty shouldEqual hasFeatureFeature
685687
}
686688

687689
private def getMetadata(outputColName: String, transformedData: DataFrame): Metadata = {

0 commit comments

Comments
 (0)