Skip to content

Commit 64bb116

Browse files
iindyk and tfx-copybara
authored and committed
Allows a TensorRepresentation and its source feature to share a name.
Sharing the name makes working with simpler schemas less error-prone and less ambiguous. PiperOrigin-RevId: 478584256
1 parent 5842bbc commit 64bb116

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
in the same order as the terms in documents.
1010
* `schema_utils.schema_as_feature_spec` now supports struct features as a way
1111
to describe `tf.SequenceExample` data.
12+
* TensorRepresentations in schema used for
13+
`schema_utils.schema_as_feature_spec` can now share a name with their source
14+
features.
1215

1316
## Bug Fixes and Other Changes
1417

tensorflow_transform/tf_metadata/schema_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,15 +319,18 @@ def schema_as_feature_spec(
319319
schema_proto, TENSOR_REPRESENTATION_GROUP))
320320
if tensor_representations is not None:
321321
for name, tensor_representation in tensor_representations.items():
322-
if name in feature_by_name:
323-
raise ValueError(
324-
'Ragged TensorRepresentation name "{}" conflicts with a different '
325-
'feature in the same schema.'.format(name))
326322
(feature_spec[name], domains[name]) = (
327323
_ragged_tensor_representation_as_feature_spec(name,
328324
tensor_representation,
329325
feature_by_name,
330326
string_domains))
327+
# At this point `feature_by_name` does not have source features for this
328+
# tensor representation. If there's still a feature with the same name,
329+
# then it would result in a name conflict.
330+
if name in feature_by_name:
331+
raise ValueError(
332+
'Ragged TensorRepresentation name "{}" conflicts with a different '
333+
'feature in the same schema.'.format(name))
331334

332335
# Generate a `tf.FixedLenFeature` or `tf.VarLenFeature` for each element of
333336
# `schema_proto.feature` that was not referenced by a `SparseFeature` or a

tensorflow_transform/tf_metadata/schema_utils_test_cases.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,38 @@
785785
row_splits_dtype=tf.int64),
786786
},
787787
},
788+
{
789+
'testcase_name':
790+
'ragged_tensor_and_feature_same_name',
791+
'ascii_proto':
792+
"""
793+
feature {
794+
name: "ragged"
795+
type: FLOAT
796+
}
797+
tensor_representation_group {
798+
key: ""
799+
value {
800+
tensor_representation {
801+
key: "ragged"
802+
value {
803+
ragged_tensor {
804+
feature_path { step: "ragged" }
805+
}
806+
}
807+
}
808+
}
809+
}
810+
""",
811+
'feature_spec': {
812+
'ragged':
813+
tf.io.RaggedFeature(
814+
tf.float32,
815+
value_key='ragged',
816+
partitions=[],
817+
row_splits_dtype=tf.int64),
818+
},
819+
},
788820
])
789821

790822
NON_ROUNDTRIP_SCHEMAS.extend([{

0 commit comments

Comments
 (0)