Skip to content

Commit b2ede54

Browse files
iindyk authored and tfx-copybara committed
Adds support for handling struct features in schema_as_feature_spec.
Presence of a struct feature is interpreted as the physical format being `tf.SequenceExample`. In such a case, `schema_as_feature_spec` returns the union of context and sequence features. This represents how `preprocessing_fn` sees parsed features (e.g. with the TfSequenceExampleRecord tfxio). PiperOrigin-RevId: 474625238
1 parent e11c104 commit b2ede54

File tree

3 files changed

+299
-5
lines changed

3 files changed

+299
-5
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
* Introduced `tft.experimental.document_frequency` and `tft.experimental.idf`
88
which map each term to its document frequency and inverse document frequency
99
in the same order as the terms in documents.
10+
* `schema_utils.schema_as_feature_spec` now supports struct features as a way
11+
to describe `tf.SequenceExample` data.
1012

1113
## Bug Fixes and Other Changes
1214

tensorflow_transform/tf_metadata/schema_utils.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Utilities for using the tf.Metadata Schema within TensorFlow."""
1515

1616
import typing
17-
from typing import Dict, List, Mapping, Optional, Tuple
17+
from typing import Dict, List, Mapping, Optional, Tuple, Union
1818

1919
import tensorflow as tf
2020

@@ -262,15 +262,19 @@ def schema_as_feature_spec(
262262
For a Feature with a FixedShape we generate a FixedLenFeature with no default.
263263
For a Feature without a FixedShape we generate a VarLenFeature. For a
264264
SparseFeature we generate a SparseFeature.
265+
If schema contains struct feature, then it must also contain
266+
TensorRepresentations and is assumed to describe SequenceExample data. The
267+
result in such case is union of context and sequence feature specs.
265268
266269
Args:
267270
schema_proto: A Schema proto.
268271
269272
Returns:
270273
A pair (feature spec, domains) where feature spec is a dict whose keys are
271-
feature names and values are instances of FixedLenFeature, VarLenFeature
272-
or SparseFeature, and `domains` is a dict whose keys are feature names
273-
and values are one of the `domain_info` oneof, e.g. IntDomain.
274+
feature names and values are instances of FixedLenFeature,
275+
VarLenFeature, SparseFeature or RaggedFeature, and `domains` is a dict
276+
whose keys are feature names and values are one of the `domain_info`
277+
oneof, e.g. IntDomain.
274278
275279
Raises:
276280
ValueError: If the schema proto is invalid.
@@ -285,6 +289,12 @@ def schema_as_feature_spec(
285289

286290
if schema_utils_legacy.get_generate_legacy_feature_spec(schema_proto):
287291
return _legacy_schema_as_feature_spec(schema_proto)
292+
293+
# Presence of a struct means that data's physical format is tf.SequenceExample
294+
# and the struct contains sequence features.
295+
if any(feature.type == schema_pb2.STRUCT for feature in schema_proto.feature):
296+
return _sequence_schema_as_feature_spec(schema_proto)
297+
288298
feature_spec = {}
289299
# Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature. For
290300
# sparse features, will hold the domain_info of the values feature. Features
@@ -335,7 +345,71 @@ def schema_as_feature_spec(
335345
return SchemaAsFeatureSpecResult(feature_spec, domains)
336346

337347

338-
def _get_string_domains(schema):
348+
def _sequence_schema_as_feature_spec(
    schema: schema_pb2.Schema) -> SchemaAsFeatureSpecResult:
  """Builds a combined feature spec for a Schema of tf.SequenceExample data.

  The context and sequence parser configs returned by
  `tensor_representation_util.CreateTfSequenceExampleParserConfig` are merged
  into a single dict (see that function's docstring for how the individual
  specs are derived). Merging mirrors the way preprocessing_fn receives its
  inputs: every feature is a top-level entry of one `inputs` dict. Because of
  the merge, the result alone is not enough to reconstruct the original schema,
  since context and sequence features can no longer be told apart.

  Args:
    schema: A TFMD Schema proto.

  Returns:
    A pair (feature spec, domains). The feature spec maps feature names to
    FixedLenFeature, VarLenFeature, SparseFeature or RaggedFeature instances.
    `domains` maps feature names to one of the `domain_info` oneof values,
    e.g. IntDomain; features without a known domain are absent from it.

  Raises:
    ValueError: If a generated spec is not a FixedLenFeature, VarLenFeature,
      SparseFeature or a ragged feature (as judged by
      `common_types.is_ragged_feature`).
  """
  context_specs, sequence_specs = (
      tensor_representation_util.CreateTfSequenceExampleParserConfig(schema))
  # Sequence entries are applied last, so they win on duplicate keys.
  merged_spec = dict(context_specs)
  merged_spec.update(sequence_specs)

  source_domains = _get_source_feature_domains(
      schema, _get_string_domains(schema))

  # NOTE(review): domains are looked up by bare source feature name, so a
  # context feature and a sequence feature that share a source name resolve to
  # the same entry -- confirm this is intended.
  domains = {}
  for name, spec in merged_spec.items():
    if isinstance(spec, tf.io.FixedLenFeature) or isinstance(
        spec, tf.io.VarLenFeature):
      # Dense and var-len specs are keyed by the source feature itself.
      source_name = name
    elif isinstance(
        spec, tf.io.SparseFeature) or common_types.is_ragged_feature(spec):
      # Composite specs take the domain of the feature that holds the values.
      source_name = spec.value_key
    else:
      raise ValueError('spec is not recognized')
    if source_name not in source_domains:
      continue
    domains[name] = source_domains[source_name]
  return SchemaAsFeatureSpecResult(merged_spec, domains)
391+
392+
393+
def _get_source_feature_domains(
    schema_or_domain: Union[schema_pb2.Schema, schema_pb2.StructDomain],
    string_domains: Dict[str, schema_pb2.StringDomain]
) -> Dict[str, common_types.DomainType]:
  """Collects domains of all source features, descending into struct domains.

  Args:
    schema_or_domain: A Schema or StructDomain proto whose `feature` entries
      are inspected.
    string_domains: String domains keyed by name; passed through to
      `_get_domain` for every non-struct feature.

  Returns:
    A flat dict from feature name to its domain. Features for which
    `_get_domain` returns None are left out. Entries are keyed by the feature's
    own name only, so a later feature overwrites an earlier one of the same
    name, including across nesting levels.
  """
  domains = {}
  for feature in schema_or_domain.feature:
    if feature.WhichOneof('domain_info') != 'struct_domain':
      leaf_domain = _get_domain(feature, string_domains)
      if leaf_domain is not None:
        domains[feature.name] = leaf_domain
      continue
    # Struct feature: merge in whatever its nested features contribute.
    nested_domains = _get_source_feature_domains(feature.struct_domain,
                                                 string_domains)
    domains.update(nested_domains)
  return domains
409+
410+
411+
def _get_string_domains(
    schema: schema_pb2.Schema) -> Dict[str, schema_pb2.StringDomain]:
  """Indexes the entries of `schema.string_domain` by their `name` field."""
  domains_by_name = {}
  for string_domain in schema.string_domain:
    domains_by_name[string_domain.name] = string_domain
  return domains_by_name
340414

341415

tensorflow_transform/tf_metadata/schema_utils_test_cases.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,224 @@
787787
},
788788
])
789789

790+
# Schemas that contain a STRUCT feature named "##SEQUENCE##" whose struct_domain
# holds nested features, together with a default ("") tensor_representation_group.
# Each entry pairs the schema text with a 'feature_spec' dict; the last two
# entries additionally carry a 'domains' dict.
NON_ROUNDTRIP_SCHEMAS.extend([{
    # A top-level feature and a nested feature share the name "int_feature";
    # the nested one is exposed under the key "seq_int_feature".
    'testcase_name':
        'sequence',
    'ascii_proto':
        """
          feature {
            name: "int_feature"
            type: INT
            value_count {
              min: 1
              max: 1
            }
          }

          feature {
            name: "##SEQUENCE##"
            type: STRUCT
            struct_domain {
              feature {
                name: "int_feature"
                type: INT
                value_count {
                  min: 0
                  max: 2
                }
              }
            }
          }
          tensor_representation_group {
            key: ""
            value {
              tensor_representation {
                key: "int_feature"
                value { varlen_sparse_tensor { column_name: "int_feature" } }
              }
              tensor_representation {
                key: "seq_int_feature"
                value {
                  ragged_tensor {
                    feature_path { step: "##SEQUENCE##" step: "int_feature" }
                  }
                }
              }
            }
          }
        """,
    'feature_spec': {
        'int_feature':
            tf.io.VarLenFeature(dtype=tf.int64),
        'seq_int_feature':
            tf.io.RaggedFeature(
                dtype=tf.int64,
                value_key='int_feature',
                partitions=[],
                row_splits_dtype=tf.int64,
                validate=False),
    },
}, {
    # The struct feature is the only top-level feature in this schema.
    'testcase_name':
        'sequence_no_context',
    'ascii_proto':
        """
          feature {
            name: "##SEQUENCE##"
            type: STRUCT
            struct_domain {
              feature {
                name: "x"
                type: INT
                value_count {
                  min: 0
                  max: 2
                }
              }
            }
          }
          tensor_representation_group {
            key: ""
            value {
              tensor_representation {
                key: "x"
                value { ragged_tensor {
                  feature_path { step: "##SEQUENCE##" step: "x" } } }
              }
            }
          }
        """,
    'feature_spec': {
        'x':
            tf.io.RaggedFeature(
                dtype=tf.int64,
                value_key='x',
                partitions=[],
                row_splits_dtype=tf.int64,
                validate=False),
    },
}, {
    # Both the top-level feature and the nested feature declare a domain; the
    # nested feature's domain is listed under its tensor representation key.
    'testcase_name':
        'sequence_with_domains',
    'ascii_proto':
        """
          feature {
            name: "int_feature"
            type: INT
            value_count {
              min: 1
              max: 1
            }
            int_domain { min: 0 max: 9 }
          }

          feature {
            name: "##SEQUENCE##"
            type: STRUCT
            struct_domain {
              feature {
                name: "float_feature"
                type: FLOAT
                value_count {
                  min: 0
                  max: 2
                }
                float_domain { min: 1.0}
              }
            }
          }
          tensor_representation_group {
            key: ""
            value {
              tensor_representation {
                key: "int_feature"
                value { varlen_sparse_tensor { column_name: "int_feature" } }
              }
              tensor_representation {
                key: "seq_float_feature"
                value {
                  ragged_tensor {
                    feature_path { step: "##SEQUENCE##" step: "float_feature" }
                  }
                }
              }
            }
          }
        """,
    'feature_spec': {
        'int_feature':
            tf.io.VarLenFeature(dtype=tf.int64),
        'seq_float_feature':
            tf.io.RaggedFeature(
                dtype=tf.float32,
                value_key='float_feature',
                partitions=[],
                row_splits_dtype=tf.int64,
                validate=False),
    },
    'domains': {
        'int_feature': schema_pb2.IntDomain(min=0, max=9),
        'seq_float_feature': schema_pb2.FloatDomain(min=1.0)
    }
}, {
    # Only the nested feature declares a domain (an inline string_domain), so
    # 'domains' has no entry for "int_feature".
    'testcase_name':
        'sequence_with_string_domain',
    'ascii_proto':
        """
          feature {
            name: "int_feature"
            type: INT
          }

          feature {
            name: "##SEQUENCE##"
            type: STRUCT
            struct_domain {
              feature {
                name: "string_feature"
                type: BYTES
                value_count {
                  min: 0
                  max: 2
                }
                string_domain {value: "a" value: "b"}
              }
            }
          }
          tensor_representation_group {
            key: ""
            value {
              tensor_representation {
                key: "int_feature"
                value { varlen_sparse_tensor { column_name: "int_feature" } }
              }
              tensor_representation {
                key: "seq_string_feature"
                value {
                  ragged_tensor {
                    feature_path { step: "##SEQUENCE##" step: "string_feature" }
                  }
                }
              }
            }
          }
        """,
    'feature_spec': {
        'int_feature':
            tf.io.VarLenFeature(dtype=tf.int64),
        'seq_string_feature':
            tf.io.RaggedFeature(
                dtype=tf.string,
                value_key='string_feature',
                partitions=[],
                row_splits_dtype=tf.int64,
                validate=False),
    },
    'domains': {
        'seq_string_feature': schema_pb2.StringDomain(value=['a', 'b'])
    }
}])
1007+
7901008
INVALID_SCHEMA_PROTOS.extend([{
7911009
'testcase_name':
7921010
'ragged_feature_non_int_row_lengths',

0 commit comments

Comments
 (0)