Skip to content

Commit d59b84d

Browse files
zoyahav authored and tfx-copybara committed
Automated rollback of commit 067cf11
PiperOrigin-RevId: 551536436
1 parent 0ed5b29 commit d59b84d

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

tensorflow_transform/beam/impl.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1360,11 +1360,17 @@ def expand(self, dataset):
13601360

13611361
def _remove_columns_from_metadata(metadata, excluded_columns):
13621362
"""Remove columns from metadata without mutating original metadata."""
1363-
feature_spec, domains = schema_utils.schema_as_feature_spec(metadata.schema)
1364-
new_feature_spec = {name: spec for name, spec in feature_spec.items()
1365-
if name not in excluded_columns}
1366-
new_domains = {name: spec for name, spec in domains.items()
1367-
if name not in excluded_columns}
1363+
generated = schema_utils.schema_as_feature_spec(metadata.schema)
1364+
new_feature_spec = {
1365+
name: spec
1366+
for name, spec in generated.feature_spec.items()
1367+
if name not in excluded_columns
1368+
}
1369+
new_domains = {
1370+
name: spec
1371+
for name, spec in generated.domains.items()
1372+
if name not in excluded_columns
1373+
}
13681374
return dataset_metadata.DatasetMetadata.from_feature_spec(
13691375
new_feature_spec, new_domains)
13701376

tensorflow_transform/tf_metadata/schema_utils.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,17 +13,14 @@
1313
# limitations under the License.
1414
"""Utilities for using the tf.Metadata Schema within TensorFlow."""
1515

16+
import dataclasses
1617
import typing
1718
from typing import Dict, List, Mapping, Optional, Tuple, Union
1819

1920
import tensorflow as tf
20-
2121
from tensorflow_transform import common_types
2222
from tensorflow_transform.tf_metadata import schema_utils_legacy
2323
from tfx_bsl.tfxio import tensor_representation_util
24-
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
25-
# once the Spark issue is resolved.
26-
from tfx_bsl.types import tfx_namedtuple
2724

2825
from tensorflow_metadata.proto.v0 import path_pb2
2926
from tensorflow_metadata.proto.v0 import schema_pb2
@@ -245,10 +242,15 @@ def _set_domain(name, feature, domain):
245242
raise ValueError('Feature "{}" has invalid domain {}'.format(name, domain))
246243

247244

248-
SchemaAsFeatureSpecResult = tfx_namedtuple.TypedNamedTuple(
249-
'SchemaAsFeatureSpecResult',
250-
[('feature_spec', Dict[str, common_types.FeatureSpecType]),
251-
('domains', Dict[str, common_types.DomainType])])
245+
@dataclasses.dataclass(frozen=True)
246+
class SchemaAsFeatureSpecResult:
247+
feature_spec: Dict[str, common_types.FeatureSpecType]
248+
domains: Dict[str, common_types.DomainType]
249+
250+
# This is needed because many users unpack this with:
251+
# `feature_spec, domains = schema_utils.schema_as_feature_spec()`.
252+
def __iter__(self):
253+
return (getattr(self, field.name) for field in dataclasses.fields(self))
252254

253255

254256
def _standardize_default_value(

tensorflow_transform/tf_metadata/schema_utils_test.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,10 @@ def test_schema_as_feature_spec(
4646
schema_utils_legacy.set_generate_legacy_feature_spec(
4747
schema_proto, generate_legacy_feature_spec)
4848
result = schema_utils.schema_as_feature_spec(schema_proto)
49-
self.assertEqual(result, (feature_spec, domains or {}))
49+
self.assertEqual(
50+
result,
51+
schema_utils.SchemaAsFeatureSpecResult(feature_spec, domains or {}),
52+
)
5053

5154
@parameterized.named_parameters(
5255
*schema_utils_test_cases.INVALID_SCHEMA_PROTOS)

0 commit comments

Comments
 (0)