Skip to content

Commit 067cf11

Browse files
zoyahavtfx-copybara
authored andcommitted
Automated rollback of commit c04a131
PiperOrigin-RevId: 550599290
1 parent c04a131 commit 067cf11

File tree

3 files changed

+14
-25
lines changed

3 files changed

+14
-25
lines changed

tensorflow_transform/beam/impl.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,17 +1360,11 @@ def expand(self, dataset):
13601360

13611361
def _remove_columns_from_metadata(metadata, excluded_columns):
13621362
"""Remove columns from metadata without mutating original metadata."""
1363-
generated = schema_utils.schema_as_feature_spec(metadata.schema)
1364-
new_feature_spec = {
1365-
name: spec
1366-
for name, spec in generated.feature_spec.items()
1367-
if name not in excluded_columns
1368-
}
1369-
new_domains = {
1370-
name: spec
1371-
for name, spec in generated.domains.items()
1372-
if name not in excluded_columns
1373-
}
1363+
feature_spec, domains = schema_utils.schema_as_feature_spec(metadata.schema)
1364+
new_feature_spec = {name: spec for name, spec in feature_spec.items()
1365+
if name not in excluded_columns}
1366+
new_domains = {name: spec for name, spec in domains.items()
1367+
if name not in excluded_columns}
13741368
return dataset_metadata.DatasetMetadata.from_feature_spec(
13751369
new_feature_spec, new_domains)
13761370

tensorflow_transform/tf_metadata/schema_utils.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414
"""Utilities for using the tf.Metadata Schema within TensorFlow."""
1515

16-
import dataclasses
1716
import typing
1817
from typing import Dict, List, Mapping, Optional, Tuple, Union
1918

2019
import tensorflow as tf
20+
2121
from tensorflow_transform import common_types
2222
from tensorflow_transform.tf_metadata import schema_utils_legacy
2323
from tfx_bsl.tfxio import tensor_representation_util
24+
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
25+
# once the Spark issue is resolved.
26+
from tfx_bsl.types import tfx_namedtuple
2427

2528
from tensorflow_metadata.proto.v0 import path_pb2
2629
from tensorflow_metadata.proto.v0 import schema_pb2
@@ -242,15 +245,10 @@ def _set_domain(name, feature, domain):
242245
raise ValueError('Feature "{}" has invalid domain {}'.format(name, domain))
243246

244247

245-
@dataclasses.dataclass(frozen=True)
246-
class SchemaAsFeatureSpecResult:
247-
feature_spec: Dict[str, common_types.FeatureSpecType]
248-
domains: Dict[str, common_types.DomainType]
249-
250-
# This is needed because many users unpack this with:
251-
# `feature_spec, domains = schema_utils.schema_as_feature_spec()`.
252-
def __iter__(self):
253-
return (getattr(self, field.name) for field in dataclasses.fields(self))
248+
SchemaAsFeatureSpecResult = tfx_namedtuple.TypedNamedTuple(
249+
'SchemaAsFeatureSpecResult',
250+
[('feature_spec', Dict[str, common_types.FeatureSpecType]),
251+
('domains', Dict[str, common_types.DomainType])])
254252

255253

256254
def _standardize_default_value(

tensorflow_transform/tf_metadata/schema_utils_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ def test_schema_as_feature_spec(
4646
schema_utils_legacy.set_generate_legacy_feature_spec(
4747
schema_proto, generate_legacy_feature_spec)
4848
result = schema_utils.schema_as_feature_spec(schema_proto)
49-
self.assertEqual(
50-
result,
51-
schema_utils.SchemaAsFeatureSpecResult(feature_spec, domains or {}),
52-
)
49+
self.assertEqual(result, (feature_spec, domains or {}))
5350

5451
@parameterized.named_parameters(
5552
*schema_utils_test_cases.INVALID_SCHEMA_PROTOS)

0 commit comments

Comments
 (0)