Skip to content

Commit 1ac6051

Browse files
cavenesstfx-copybara
authored andcommitted
Modify how TFDV updates the schema in response to a value nestedness mismatch anomaly. Instead of clearing value_count(s), TFDV updates the value_count(s) based on the stats, just as it does when generating a schema anew.
PiperOrigin-RevId: 325452888
1 parent cf9d57a commit 1ac6051

File tree

3 files changed

+197
-76
lines changed

3 files changed

+197
-76
lines changed

tensorflow_data_validation/anomalies/feature_util.cc

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,42 @@ absl::optional<FeatureStatsView> GetControlStats(
7171
}
7272
}
7373

74+
void InitValueCount(const FeatureStatsView& feature_stats_view,
75+
Feature* feature) {
76+
// Set value_counts or value_count, depending on whether the feature's values
77+
// are nested.
78+
const std::vector<std::pair<int, int>> min_max_num_values =
79+
feature_stats_view.GetMinMaxNumValues();
80+
auto set_value_count = [](int min_num_values, int max_num_values,
81+
ValueCount* value_count) {
82+
if (min_num_values > 0) {
83+
if (min_num_values == max_num_values) {
84+
// Set min and max value count in the schema if they are same. This
85+
// would allow required features with same valency to be parsed as dense
86+
// tensors in TFT.
87+
value_count->set_min(min_num_values);
88+
value_count->set_max(max_num_values);
89+
} else {
90+
value_count->set_min(1);
91+
}
92+
}
93+
};
94+
if (feature_stats_view.HasNestedValues()) {
95+
for (int i = 0; i < min_max_num_values.size(); i++) {
96+
set_value_count(min_max_num_values[i].first, min_max_num_values[i].second,
97+
feature->mutable_value_counts()->add_value_count());
98+
}
99+
} else if (min_max_num_values.size() == 1 &&
100+
min_max_num_values[0].first > 0) {
101+
set_value_count(min_max_num_values[0].first, min_max_num_values[0].second,
102+
feature->mutable_value_count());
103+
}
104+
}
105+
74106
std::vector<Description> UpdateValueCount(
75-
const std::vector<std::pair<int, int>>& min_max_num_values,
76-
Feature* feature) {
107+
const FeatureStatsView& feature_stats_view, Feature* feature) {
108+
const std::vector<std::pair<int, int>> min_max_num_values =
109+
feature_stats_view.GetMinMaxNumValues();
77110
std::vector<Description> description;
78111
if (min_max_num_values.size() > 1) {
79112
description.push_back(
@@ -82,6 +115,7 @@ std::vector<Description> UpdateValueCount(
82115
"feature > 1. For features with nestedness levels greater than 1, "
83116
"value_counts, not value_count, should be specified."});
84117
feature->clear_value_count();
118+
InitValueCount(feature_stats_view, feature);
85119
return description;
86120
}
87121
if (feature->value_count().has_min() &&
@@ -106,15 +140,17 @@ std::vector<Description> UpdateValueCount(
106140
}
107141

108142
std::vector<Description> UpdateValueCounts(
109-
const std::vector<std::pair<int, int>>& min_max_num_values,
110-
Feature* feature) {
143+
const FeatureStatsView& feature_stats_view, Feature* feature) {
144+
const std::vector<std::pair<int, int>> min_max_num_values =
145+
feature_stats_view.GetMinMaxNumValues();
111146
std::vector<Description> description;
112147
if (feature->value_counts().value_count_size() != min_max_num_values.size()) {
113148
description.push_back(
114149
{AnomalyInfo::VALUE_NESTEDNESS_MISMATCH, kValueNestednessMismatch,
115150
"The values have a different nest level than expected. Value counts "
116151
"will not be checked."});
117152
feature->clear_value_counts();
153+
InitValueCount(feature_stats_view, feature);
118154
return description;
119155
}
120156
for (int i = 0; i < feature->value_counts().value_count_size(); ++i) {
@@ -156,13 +192,11 @@ std::vector<Description> UpdateFeatureValueCounts(
156192
if (!feature->has_value_count() && !feature->has_value_counts()) {
157193
return {};
158194
}
159-
const std::vector<std::pair<int, int>> min_max_num_values =
160-
feature_stats_view.GetMinMaxNumValues();
161195
if (feature->has_value_count()) {
162-
return UpdateValueCount(min_max_num_values, feature);
196+
return UpdateValueCount(feature_stats_view, feature);
163197
}
164198
if (feature->has_value_counts()) {
165-
return UpdateValueCounts(min_max_num_values, feature);
199+
return UpdateValueCounts(feature_stats_view, feature);
166200
}
167201
return {};
168202
}
@@ -404,34 +438,7 @@ void InitValueCountAndPresence(const FeatureStatsView& feature_stats_view,
404438
// Required feature.
405439
feature->mutable_presence()->set_min_fraction(1.0);
406440
}
407-
// Set value_counts or value_count, depending on whether the feature's values
408-
// are nested.
409-
const std::vector<std::pair<int, int>> min_max_num_values =
410-
feature_stats_view.GetMinMaxNumValues();
411-
auto set_value_count = [](int min_num_values, int max_num_values,
412-
ValueCount* value_count) {
413-
if (min_num_values > 0) {
414-
if (min_num_values == max_num_values) {
415-
// Set min and max value count in the schema if they are same. This
416-
// would allow required features with same valency to be parsed as dense
417-
// tensors in TFT.
418-
value_count->set_min(min_num_values);
419-
value_count->set_max(max_num_values);
420-
} else {
421-
value_count->set_min(1);
422-
}
423-
}
424-
};
425-
if (feature_stats_view.HasNestedValues()) {
426-
for (int i = 0; i < min_max_num_values.size(); i++) {
427-
set_value_count(min_max_num_values[i].first, min_max_num_values[i].second,
428-
feature->mutable_value_counts()->add_value_count());
429-
}
430-
} else if (min_max_num_values.size() == 1 &&
431-
min_max_num_values[0].first > 0) {
432-
set_value_count(min_max_num_values[0].first, min_max_num_values[0].second,
433-
feature->mutable_value_count());
434-
}
441+
InitValueCount(feature_stats_view, feature);
435442
}
436443

437444
std::vector<Description> UpdatePresence(

tensorflow_data_validation/anomalies/feature_util_test.cc

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,10 @@ GetUpdateFeatureValueCountsTests() {
730730
{AnomalyInfo::VALUE_NESTEDNESS_MISMATCH},
731731
{"Mismatched value nest level"},
732732
ParseTextProtoOrDie<Feature>(R"(value_count: { min: 1 max: 1 })"),
733-
Feature()},
733+
ParseTextProtoOrDie<Feature>(R"(value_counts: {
734+
value_count: { min: 1 max: 1 }
735+
value_count: { min: 1 max: 1 }
736+
})")},
734737
{"num_values_outside_value_count_bounds",
735738
ParseTextProtoOrDie<FeatureNameStatistics>(R"(
736739
name: 'feature'
@@ -802,44 +805,50 @@ GetUpdateFeatureValueCountsTests() {
802805
value_count: { min: 1 max: 1 }
803806
value_count: { min: 1 max: 1 }
804807
})")},
805-
{"value_counts_nestedness_mismatch",
806-
ParseTextProtoOrDie<FeatureNameStatistics>(R"(
807-
name: 'feature'
808-
type: FLOAT
809-
num_stats: {
810-
common_stats: {
811-
num_missing: 0
812-
num_non_missing: 10
813-
min_num_values: 1
814-
max_num_values: 1
815-
presence_and_valency_stats {
816-
num_missing: 0
817-
num_non_missing: 10
818-
min_num_values: 1
819-
max_num_values: 1
820-
}
821-
presence_and_valency_stats {
822-
num_missing: 0
823-
num_non_missing: 10
824-
min_num_values: 1
825-
max_num_values: 1
826-
}
827-
presence_and_valency_stats {
828-
num_missing: 0
829-
num_non_missing: 10
830-
min_num_values: 1
831-
max_num_values: 1
832-
}
833-
}
834-
})"),
835-
false,
836-
{AnomalyInfo::VALUE_NESTEDNESS_MISMATCH},
837-
{"Mismatched value nest level"},
838-
ParseTextProtoOrDie<Feature>(R"(value_counts: {
839-
value_count: { min: 1 max: 1 }
840-
value_count: { min: 1 max: 1 }
841-
})"),
842-
Feature()},
808+
{
809+
"value_counts_nestedness_mismatch",
810+
ParseTextProtoOrDie<FeatureNameStatistics>(R"(
811+
name: 'feature'
812+
type: FLOAT
813+
num_stats: {
814+
common_stats: {
815+
num_missing: 0
816+
num_non_missing: 10
817+
min_num_values: 1
818+
max_num_values: 1
819+
presence_and_valency_stats {
820+
num_missing: 0
821+
num_non_missing: 10
822+
min_num_values: 1
823+
max_num_values: 1
824+
}
825+
presence_and_valency_stats {
826+
num_missing: 0
827+
num_non_missing: 10
828+
min_num_values: 1
829+
max_num_values: 1
830+
}
831+
presence_and_valency_stats {
832+
num_missing: 0
833+
num_non_missing: 10
834+
min_num_values: 1
835+
max_num_values: 1
836+
}
837+
}
838+
})"),
839+
false,
840+
{AnomalyInfo::VALUE_NESTEDNESS_MISMATCH},
841+
{"Mismatched value nest level"},
842+
ParseTextProtoOrDie<Feature>(R"(value_counts: {
843+
value_count: { min: 1 max: 1 }
844+
value_count: { min: 1 max: 1 }
845+
})"),
846+
ParseTextProtoOrDie<Feature>(R"(value_counts: {
847+
value_count: { min: 1 max: 1 }
848+
value_count: { min: 1 max: 1 }
849+
value_count: { min: 1 max: 1 }
850+
})"),
851+
},
843852
{"num_values_outside_value_counts_bounds",
844853
ParseTextProtoOrDie<FeatureNameStatistics>(R"(
845854
name: 'feature'

tensorflow_data_validation/integration_tests/sequence_example_e2e_test.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,103 @@
15471547
anomaly_name_format: SERIALIZED_PATH
15481548
"""
15491549

1550+
_BASIC_SCHEMA_FROM_UPDATE = """
1551+
feature {
1552+
name: "context_bytes_feature"
1553+
value_count {
1554+
min: 1
1555+
max: 1
1556+
}
1557+
type: BYTES
1558+
bool_domain {
1559+
true_value: "1"
1560+
false_value: "0"
1561+
}
1562+
presence {
1563+
min_fraction: 1.0
1564+
min_count: 1
1565+
}
1566+
}
1567+
feature {
1568+
name: "context_int64_feature"
1569+
type: INT
1570+
presence {
1571+
min_count: 1
1572+
}
1573+
}
1574+
feature {
1575+
name: "example_weight"
1576+
value_count {
1577+
min: 1
1578+
max: 1
1579+
}
1580+
type: FLOAT
1581+
presence {
1582+
min_fraction: 1.0
1583+
min_count: 1
1584+
}
1585+
}
1586+
feature {
1587+
name: "label"
1588+
value_count {
1589+
min: 1
1590+
max: 1
1591+
}
1592+
type: FLOAT
1593+
presence {
1594+
min_fraction: 1.0
1595+
min_count: 1
1596+
}
1597+
}
1598+
feature {
1599+
name: "##SEQUENCE##"
1600+
value_count {
1601+
min: 1
1602+
max: 1
1603+
}
1604+
type: STRUCT
1605+
presence {
1606+
min_fraction: 1.0
1607+
min_count: 1
1608+
}
1609+
struct_domain {
1610+
feature {
1611+
name: "sequence_float_feature"
1612+
type: FLOAT
1613+
presence {
1614+
min_count: 1
1615+
}
1616+
value_counts {
1617+
value_count {
1618+
min: 1
1619+
max: 1
1620+
}
1621+
value_count {
1622+
min: 2
1623+
max: 2
1624+
}
1625+
}
1626+
}
1627+
feature {
1628+
name: "sequence_int64_feature"
1629+
type: INT
1630+
presence {
1631+
min_fraction: 1.0
1632+
min_count: 1
1633+
}
1634+
value_counts {
1635+
value_count {
1636+
min: 1
1637+
}
1638+
value_count {
1639+
max: 3
1640+
}
1641+
}
1642+
}
1643+
}
1644+
}
1645+
"""
1646+
15501647
# Do not inline the goldens in _TEST_CASES. This way indentation is easier to
15511648
# manage. The rule is to have no first level indent for goldens.
15521649
_TEST_CASES = [
@@ -1562,6 +1659,7 @@
15621659
expected_inferred_schema_pbtxt=_BASIC_GOLDEN_INFERRED_SCHEMA,
15631660
schema_for_validation_pbtxt=_BASIC_SCHEMA_FOR_VALIDATION,
15641661
expected_anomalies_pbtxt=_BASIC_GOLDEN_ANOMALIES,
1662+
expected_updated_schema_pbtxt=_BASIC_SCHEMA_FROM_UPDATE,
15651663
),
15661664
dict(
15671665
testcase_name='weight_and_label',
@@ -1577,6 +1675,7 @@
15771675
expected_inferred_schema_pbtxt=_BASIC_GOLDEN_INFERRED_SCHEMA,
15781676
schema_for_validation_pbtxt=_BASIC_SCHEMA_FOR_VALIDATION,
15791677
expected_anomalies_pbtxt=_BASIC_GOLDEN_ANOMALIES,
1678+
expected_updated_schema_pbtxt=_BASIC_SCHEMA_FROM_UPDATE,
15801679
)
15811680
]
15821681

@@ -1639,7 +1738,7 @@ def _assert_features_equal(lhs, rhs):
16391738
@parameterized.named_parameters(*_TEST_CASES)
16401739
def test_e2e(self, stats_options, expected_stats_pbtxt,
16411740
expected_inferred_schema_pbtxt, schema_for_validation_pbtxt,
1642-
expected_anomalies_pbtxt):
1741+
expected_anomalies_pbtxt, expected_updated_schema_pbtxt):
16431742
tfxio = tf_sequence_example_record.TFSequenceExampleRecord(
16441743
self._input_file, ['tfdv', 'test'])
16451744
stats_file = os.path.join(self._output_dir, 'stats')
@@ -1674,6 +1773,12 @@ def test_e2e(self, stats_options, expected_stats_pbtxt,
16741773
actual_anomalies,
16751774
text_format.Parse(expected_anomalies_pbtxt, anomalies_pb2.Anomalies()))
16761775

1776+
actual_updated_schema = tfdv.update_schema(
1777+
schema_for_validation, actual_stats, infer_feature_shape=False)
1778+
self._assert_schema_equal(
1779+
actual_updated_schema,
1780+
text_format.Parse(expected_updated_schema_pbtxt, schema_pb2.Schema()))
1781+
16771782

16781783
if __name__ == '__main__':
16791784
absltest.main()

0 commit comments

Comments
 (0)