
Commit 4356424

iindyk authored and tfx-copybara committed
Switching schema_as_feature_spec to tfx_bsl inference logic.
PiperOrigin-RevId: 493037063
1 parent 6fdf4e3 commit 4356424

File tree

3 files changed: +73, -409 lines
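To orient the reader, a minimal usage sketch (not part of this commit): the public entry point keeps its signature and result type; only the inference behind it moves to tfx_bsl. The one-feature schema below is hypothetical.

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils

# A hypothetical single-feature schema; any valid Schema proto works here.
schema = text_format.Parse(
    """
    feature {
      name: "x"
      type: INT
      value_count { min: 1 max: 1 }
      presence { min_fraction: 1.0 }
    }
    """, schema_pb2.Schema())

# After this commit the parsing spec is derived via tfx_bsl
# TensorRepresentation inference; the SchemaAsFeatureSpecResult shape of the
# return value is unchanged.
result = schema_utils.schema_as_feature_spec(schema)
print(result.feature_spec)  # e.g. {'x': FixedLenFeature(shape=(1,), ...)}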

tensorflow_transform/tf_metadata/schema_utils.py

Lines changed: 36 additions & 294 deletions
@@ -152,6 +152,7 @@ def _sparse_feature_from_feature_spec(spec, name, domains):
       s.value if isinstance(s, tf.compat.v1.Dimension) else s
       for s in spec.size
   ]
+  spec_size = [s if s != -1 else None for s in spec_size]
   int_domains = [
       schema_pb2.IntDomain(min=0, max=size - 1) if size is not None else None
       for size in spec_size
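The added line maps unknown (-1) dimensions to None so that no 0-based IntDomain is emitted for such an axis. A self-contained illustration with hypothetical sizes:

from tensorflow_metadata.proto.v0 import schema_pb2

spec_size = [10, -1]  # -1 marks a statically unknown dimension
spec_size = [s if s != -1 else None for s in spec_size]
int_domains = [
    schema_pb2.IntDomain(min=0, max=size - 1) if size is not None else None
    for size in spec_size
]
# int_domains == [IntDomain(min: 0, max: 9), None]; the unknown axis gets
# no domain instead of a bogus one.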
@@ -255,6 +256,23 @@ def _set_domain(name, feature, domain):
 RAGGED_TENSOR_TAG = 'ragged_tensor'
 
 
+def _standardize_default_value(
+    spec: tf.io.FixedLenFeature) -> tf.io.FixedLenFeature:
+  """Converts bytes to strings and unwraps lists with a single element."""
+  if spec.default_value is None:
+    return spec
+  default_value = spec.default_value
+  assert isinstance(default_value, list), spec.default_value
+  # Convert bytes to string
+  if spec.dtype == tf.string:
+    default_value = [value.decode('utf-8') for value in default_value]
+  # Unwrap a list with a single element.
+  if len(default_value) == 1:
+    default_value = default_value[0]
+  return tf.io.FixedLenFeature(
+      shape=spec.shape, dtype=spec.dtype, default_value=default_value)
+
+
 def schema_as_feature_spec(
     schema_proto: schema_pb2.Schema) -> SchemaAsFeatureSpecResult:
   """Generates a feature spec from a Schema proto.
@@ -279,72 +297,36 @@ def schema_as_feature_spec(
   Raises:
     ValueError: If the schema proto is invalid.
   """
-  for feature in schema_proto.feature:
-    if RAGGED_TENSOR_TAG in feature.annotation.tag:
-      raise ValueError(
-          'Feature "{}" had tag "{}". Features represented by a '
-          'RaggedTensor cannot be serialized/deserialized to Example proto or '
-          'other formats, and cannot have a feature spec generated for '
-          'them.'.format(feature.name, RAGGED_TENSOR_TAG))
-
-  if schema_utils_legacy.get_generate_legacy_feature_spec(schema_proto):
-    return _legacy_schema_as_feature_spec(schema_proto)
 
   # Presence of a struct means that data's physical format is tf.SequenceExample
   # and the struct contains sequence features.
   if any(feature.type == schema_pb2.STRUCT for feature in schema_proto.feature):
     return _sequence_schema_as_feature_spec(schema_proto)
 
+  tensor_representations = (
+      tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
+          schema_proto))
+
   feature_spec = {}
   # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature. For
   # sparse features, will hold the domain_info of the values feature. Features
   # that do not have a domain set will not be present in `domains`.
   domains = {}
-  feature_by_name = {feature.name: feature for feature in schema_proto.feature}
   string_domains = _get_string_domains(schema_proto)
-
-  # Generate a `tf.SparseFeature` for each element of
-  # `schema_proto.sparse_feature`. This also removed the features from
-  # feature_by_name.
-  # TODO(KesterTong): Allow sparse features to share index features.
-  for feature in schema_proto.sparse_feature:
-    if _include_in_parsing_spec(feature):
-      feature_spec[feature.name], domains[feature.name] = (
-          _sparse_feature_as_feature_spec(feature, feature_by_name,
-                                          string_domains))
-
-  # Handle ragged `TensorRepresentation`s.
-  tensor_representations = (
-      tensor_representation_util.GetTensorRepresentationsFromSchema(
-          schema_proto, TENSOR_REPRESENTATION_GROUP))
-  if tensor_representations is not None:
-    for name, tensor_representation in tensor_representations.items():
-      (feature_spec[name], domains[name]) = (
-          _ragged_tensor_representation_as_feature_spec(name,
-                                                        tensor_representation,
-                                                        feature_by_name,
-                                                        string_domains))
-      # At this point `feature_by_name` does not have source features for this
-      # tensor representation. If there's still a feature with the same name,
-      # then it would result in a name conflict.
-      if name in feature_by_name:
-        raise ValueError(
-            'Ragged TensorRepresentation name "{}" conflicts with a different '
-            'feature in the same schema.'.format(name))
-
-  # Generate a `tf.FixedLenFeature` or `tf.VarLenFeature` for each element of
-  # `schema_proto.feature` that was not referenced by a `SparseFeature` or a
-  # ragged `TensorRepresentation`.
-  for name, feature in feature_by_name.items():
-    if _include_in_parsing_spec(feature):
-      feature_spec[name], domains[name] = _feature_as_feature_spec(
-          feature, string_domains)
-
-  schema_utils_legacy.check_for_unsupported_features(schema_proto)
-
-  domains = {
-      name: domain for name, domain in domains.items() if domain is not None
-  }
+  feature_by_name = {feature.name: feature for feature in schema_proto.feature}
+  for name, tensor_representation in tensor_representations.items():
+    value_feature = str(
+        tensor_representation_util.GetSourceValueColumnFromTensorRepresentation(
+            tensor_representation))
+    spec = (
+        tensor_representation_util.CreateTfExampleParserConfig(
+            tensor_representation, feature_by_name[value_feature].type))
+    if isinstance(spec, tf.io.FixedLenFeature):
+      spec = _standardize_default_value(spec)
+    feature_spec[name] = spec
+    domain = _get_domain(feature_by_name[value_feature], string_domains)
+    if domain is not None:
+      domains[name] = domain
   return SchemaAsFeatureSpecResult(feature_spec, domains)
 
 
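A hedged sketch of the tfx_bsl calls the new body leans on, reusing the hypothetical `schema` from the sketch near the top; sparse and ragged features now flow through the same inference instead of the removed bespoke branches:

from tfx_bsl.tfxio import tensor_representation_util

tensor_representations = (
    tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
        schema))
for name, representation in tensor_representations.items():
  # Each TensorRepresentation records how to assemble the tensor, e.g.
  # dense_tensor, varlen_sparse_tensor, sparse_tensor or ragged_tensor.
  print(name, representation.WhichOneof('kind'))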
@@ -495,115 +477,6 @@ def _ragged_tensor_representation_as_feature_spec(
   return typing.cast(common_types.RaggedFeature, spec), domain
 
 
-def _sparse_feature_as_feature_spec(feature, feature_by_name, string_domains):
-  """Returns a representation of a SparseFeature as a feature spec."""
-  index_keys = [index_feature.name for index_feature in feature.index_feature]
-  index_features = []
-  for index_key in index_keys:
-    try:
-      index_features.append(feature_by_name.pop(index_key))
-    except KeyError:
-      raise ValueError(
-          'sparse_feature "{}" referred to index feature "{}" which did not '
-          'exist in the schema'.format(feature.name, index_key))
-
-  value_key = feature.value_feature.name
-  try:
-    value_feature = feature_by_name.pop(value_key)
-  except KeyError:
-    raise ValueError(
-        'sparse_feature "{}" referred to value feature "{}" which did not '
-        'exist in the schema or was referred to as an index or value multiple '
-        'times.'.format(feature.name, value_key))
-
-  shape = []
-  for index_feature, index_key in zip(index_features, index_keys):
-    if index_feature.HasField('int_domain'):
-      # Currently we only handle 0-based INT index features whose minimum
-      # domain value must be zero.
-      if not index_feature.int_domain.HasField('min'):
-        raise ValueError('Cannot determine dense shape of sparse feature '
-                         '"{}". The minimum domain value of index feature "{}"'
-                         ' is not set.'.format(feature.name, index_key))
-      if index_feature.int_domain.min != 0:
-        raise ValueError('Only 0-based index features are supported. Sparse '
-                         'feature "{}" has index feature "{}" whose minimum '
-                         'domain value is {}.'.format(
-                             feature.name, index_key,
-                             index_feature.int_domain.min))
-
-      if not index_feature.int_domain.HasField('max'):
-        raise ValueError('Cannot determine dense shape of sparse feature '
-                         '"{}". The maximum domain value of index feature "{}"'
-                         ' is not set.'.format(feature.name, index_key))
-      shape.append(index_feature.int_domain.max + 1)
-    elif len(index_keys) == 1:
-      raise ValueError('Cannot determine dense shape of sparse feature "{}".'
-                       ' The index feature "{}" had no int_domain set.'.format(
-                           feature.name, index_key))
-    else:
-      shape.append(-1)
-
-  dtype = _feature_dtype(value_feature)
-  if len(index_keys) != len(shape):
-    raise ValueError(
-        'sparse_feature "{}" had rank {} (shape {}) but {} index keys were'
-        ' given'.format(feature.name, len(shape), shape, len(index_keys)))
-  spec = tf.io.SparseFeature(index_keys, value_key, dtype, shape,
-                             feature.is_sorted)
-  domain = _get_domain(value_feature, string_domains)
-  return spec, domain
-
-
-def _feature_as_feature_spec(feature, string_domains):
-  """Returns a representation of a Feature as a feature spec."""
-  dtype = _feature_dtype(feature)
-  if feature.HasField('shape'):
-    if feature.presence.min_fraction != 1:
-      raise ValueError(
-          'Feature "{}" had shape {} set but min_fraction {} != 1. Use'
-          ' value_count not shape field when min_fraction != 1.'.format(
-              feature.name, feature.shape, feature.presence.min_fraction))
-    spec = tf.io.FixedLenFeature(
-        _fixed_shape_as_tf_shape(feature.shape), dtype, default_value=None)
-  else:
-    spec = tf.io.VarLenFeature(dtype)
-  domain = _get_domain(feature, string_domains)
-  return spec, domain
-
-
-def _feature_dtype(feature):
-  """Returns a representation of a Feature's type as a tensorflow dtype."""
-  if feature.type == schema_pb2.BYTES:
-    return tf.string
-  elif feature.type == schema_pb2.INT:
-    return tf.int64
-  elif feature.type == schema_pb2.FLOAT:
-    return tf.float32
-  else:
-    raise ValueError('Feature "{}" had invalid type {}'.format(
-        feature.name, schema_pb2.FeatureType.Name(feature.type)))
-
-
-def _fixed_shape_as_tf_shape(fixed_shape):
-  """Returns a representation of a FixedShape as a tensorflow shape."""
-  # TODO(b/120869660): Remove the cast to int. Casting to int is currently
-  # needed as some TF code explicitly checks for `int` and does not allow `long`
-  # in tensor shapes.
-  return [int(dim.size) for dim in fixed_shape.dim]
-
-
-_IGNORED_LIFECYCLE_STAGES = [
-    schema_pb2.DEPRECATED, schema_pb2.DISABLED, schema_pb2.PLANNED,
-    schema_pb2.ALPHA, schema_pb2.DEBUG_ONLY, schema_pb2.VALIDATION_DERIVED,
-]
-
-
-def _include_in_parsing_spec(feature):
-  return not (schema_utils_legacy.get_deprecated(feature) or
-              feature.lifecycle_stage in _IGNORED_LIFECYCLE_STAGES)
-
-
 def _legacy_schema_from_feature_spec(feature_spec, domains=None):
   """Infer a Schema from a feature spec, using the legacy feature spec logic.
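For reference, the BYTES -> string, INT -> int64, FLOAT -> float32 mapping that the removed _feature_dtype implemented is now applied inside tfx_bsl when the parser config is built. An equivalent table, under a hypothetical name, for readers tracking the behavior:

import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

# Hypothetical constant mirroring the removed helper's mapping.
SCHEMA_TYPE_TO_DTYPE = {
    schema_pb2.BYTES: tf.string,
    schema_pb2.INT: tf.int64,
    schema_pb2.FLOAT: tf.float32,
}
assert SCHEMA_TYPE_TO_DTYPE[schema_pb2.INT] == tf.int64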
@@ -668,134 +541,3 @@ def _legacy_schema_from_feature_spec(feature_spec, domains=None):
     _set_domain(name, feature, domains.get(name))
 
   return result
-
-
-def _legacy_schema_as_feature_spec(schema_proto):
-  """Generate a feature spec and domains using legacy feature spec."""
-  feature_spec = {}
-  # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature. For
-  # sparse features, will hold the domain_info of the values feature. Features
-  # that do not have a domain set will not be present in `domains`.
-  domains = {}
-  feature_by_name = {feature.name: feature for feature in schema_proto.feature}
-  string_domains = _get_string_domains(schema_proto)
-
-  for name, feature in feature_by_name.items():
-    if _include_in_parsing_spec(feature):
-      feature_spec[name] = _legacy_feature_as_feature_spec(feature)
-      domain = _get_domain(feature, string_domains)
-      if domain is not None:
-        domains[name] = domain
-
-  return SchemaAsFeatureSpecResult(feature_spec, domains)
-
-
-def _legacy_feature_as_feature_spec(feature):
-  """Translate a Feature proto into a TensorFlow feature spec.
-
-  This function applies heuristics to deduce the shape and other information
-  from a FeatureProto. The FeatureProto contains information about the feature
-  in an ExampleProto, but the feature spec proto also requires enough
-  information to parse the feature into a tensor. We apply the following rules:
-
-    1. The dtype is determined from the feature's type according to the mapping
-       BYTES -> string, INT -> int64, FLOAT -> float32. TYPE_UNKNOWN or any
-       other type results in a ValueError.
-
-    2. The shape and representation of the column are determined by the
-       following rules:
-         * if the value_count.min and value_count.max are both 1 then the shape
-           is scalar and the representation is fixed length.
-         * If value_count.min and value_count.max are equal but greater than 1,
-           then the shape is a vector whose length is value_count.max and the
-           representation is fixed length.
-         * If value_count.min and value_count.max are equal and are less than 1,
-           then the shape is a vector of unknown length and the representation
-           is variable length.
-         * If value_count.min and value_count.max are not equal then
-           the shape is a vector of unknown length and the representation is
-           variable length.
-
-    3. If the feature is always present or is variable length (based on the
-       above rule), no default value is set but if the feature is not always
-       present and is fixed length, then a canonical default value is chosen
-       based on _DEFAULT_VALUE_FOR_DTYPE.
-
-    4. Features that are deprecated are completely ignored and removed.
-
-  Args:
-    feature: A FeatureProto
-
-  Returns:
-    A `tf.FixedLenFeature` or `tf.VarLenFeature`.
-
-  Raises:
-    ValueError: If the feature's type is not supported or the schema is invalid.
-  """
-  # Infer canonical tensorflow dtype.
-  dtype = _feature_dtype(feature)
-
-  if feature.value_count.min < 0:
-    raise ValueError(
-        'Feature "{}" has value_count.min < 0 (value was {}).'.format(
-            feature.name, feature.value_count.min))
-
-  if feature.value_count.max < 0:
-    raise ValueError(
-        'Feature "{}" has value_count.max < 0 (value was {}).'.format(
-            feature.name, feature.value_count.max))
-
-  # Use heuristics to infer the shape and representation.
-  if (feature.value_count.min == feature.value_count.max and
-      feature.value_count.min == 1):
-    # Case 1: value_count.min == value_count.max == 1. Infer a FixedLenFeature
-    # with rank 0 and a default value.
-    tf.compat.v1.logging.info(
-        'Features %s has value_count.min == value_count.max == 1. Setting to '
-        'fixed length scalar.', feature.name)
-    default_value = _legacy_infer_default_value(feature, dtype)
-    return tf.io.FixedLenFeature([], dtype, default_value)
-
-  elif (feature.value_count.min == feature.value_count.max and
-        feature.value_count.min > 1):
-    # Case 2: value_count.min == value_count.max > 1. Infer a FixedLenFeature
-    # with rank 1 and a default value.
-    tf.compat.v1.logging.info(
-        'Feature %s has value_count.min == value_count.max > 1. Setting to '
-        'fixed length vector.', feature.name)
-    default_value = _legacy_infer_default_value(feature, dtype)
-    return tf.io.FixedLenFeature([feature.value_count.min], dtype,
-                                 default_value)
-
-  else:
-    # Case 3: Either value_count.min != value_count.max or
-    # value_count.min == value_count.max == 0. Infer a VarLenFeature.
-    tf.compat.v1.logging.info(
-        'Feature %s has value_count.min != value_count.max or '
-        ' value_count.min == value_count.max == 0. Setting to variable length '
-        ' vector.', feature.name)
-    return tf.io.VarLenFeature(dtype)
-
-
-# For numeric values, set defaults that are less likely to occur in the actual
-# data so that users can test for missing values.
-_LEGACY_DEFAULT_VALUE_FOR_DTYPE = {tf.string: '', tf.int64: -1, tf.float32: -1}
-
-
-def _legacy_infer_default_value(feature_proto, dtype):
-  """Returns a canonical default value if min_fraction < 1 or else None."""
-  if feature_proto.presence.min_fraction < 1:
-    default_value = _LEGACY_DEFAULT_VALUE_FOR_DTYPE[dtype]
-    tf.compat.v1.logging.info(
-        'Feature %s has min_fraction (%f) != 1. Setting default value %r',
-        feature_proto.name, feature_proto.presence.min_fraction, default_value)
-    if feature_proto.value_count.min == 1:
-      # neglecting vector of size 1 because that never happens.
-      return default_value
-    else:
-      return [default_value] * feature_proto.value_count.min
-  else:
-    tf.compat.v1.logging.info(
-        'Feature %s has min_fraction = 1 (%s). Setting default value to None.',
-        feature_proto.name, feature_proto.presence)
-    return None
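The deleted legacy path encoded the value_count heuristics spelled out in its docstring. A condensed, hypothetical restatement of those rules (a sketch for comparison, not the removed implementation):

import tensorflow as tf

def legacy_spec_sketch(min_count, max_count, dtype=tf.int64, min_fraction=1.0):
  """Condensed restatement of the removed legacy heuristics."""
  default = None if min_fraction >= 1 else {tf.string: '', tf.int64: -1,
                                            tf.float32: -1}[dtype]
  if min_count == max_count == 1:    # scalar, fixed length
    return tf.io.FixedLenFeature([], dtype, default)
  if min_count == max_count > 1:     # fixed-length vector
    return tf.io.FixedLenFeature(
        [min_count], dtype, None if default is None else [default] * min_count)
  return tf.io.VarLenFeature(dtype)  # everything else: variable length

assert legacy_spec_sketch(1, 1) == tf.io.FixedLenFeature([], tf.int64, None)
assert legacy_spec_sketch(0, 5) == tf.io.VarLenFeature(tf.int64)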

tensorflow_transform/tf_metadata/schema_utils_legacy.py

Lines changed: 0 additions & 14 deletions
@@ -25,17 +25,3 @@ def set_generate_legacy_feature_spec(schema_proto, value):
   raise NotImplementedError(
       'The generate_legacy_feature_spec is a legacy field that is not part '
       'of the OSS tf.Transform codebase')
-
-
-def get_generate_legacy_feature_spec(schema_proto):
-  del schema_proto  # unused
-  return False
-
-
-def check_for_unsupported_features(schema_proto):
-  del schema_proto  # unused
-
-
-def get_deprecated(feature):
-  del feature  # unused
-  return False
