Skip to content

Commit 96f96e6

Browse files
zoyahavtfx-copybara
authored andcommitted
Make annotate_sparse_output_shape support input shapes as a tensor, and simplify shape storage in collections by saving tensors.
This change also updates the census example to output one-hot tensors as sparse for efficiency. PiperOrigin-RevId: 510962121
1 parent b5ae891 commit 96f96e6

File tree

6 files changed

+83
-58
lines changed

6 files changed

+83
-58
lines changed

examples/census_example.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
# Functions for training
2727

2828

29+
def _make_inputs_dense(transformed_features):
30+
return {
31+
k: tf.sparse.to_dense(v) if isinstance(v, tf.SparseTensor) else v
32+
for k, v in transformed_features.items()
33+
}
34+
# pylint: disable=g-deprecated-tf-checker
35+
36+
2937
def _make_training_input_fn(tf_transform_output, transformed_examples,
3038
batch_size):
3139
"""Creates an input function reading from transformed data.
@@ -47,8 +55,9 @@ def input_fn():
4755
reader=tf.data.TFRecordDataset,
4856
shuffle=True)
4957

50-
transformed_features = tf.compat.v1.data.make_one_shot_iterator(
51-
dataset).get_next()
58+
transformed_features = _make_inputs_dense(
59+
tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
60+
)
5261

5362
# Extract features and label from the transformed tensors.
5463
# TODO(b/30367437): make transformed_labels a dict.
@@ -86,8 +95,9 @@ def serving_input_fn():
8695
# Apply the transform function that was used to generate the materialized
8796
# data.
8897
raw_features = serving_input_receiver.features
89-
transformed_features = tf_transform_output.transform_raw_features(
90-
raw_features)
98+
transformed_features = _make_inputs_dense(
99+
tf_transform_output.transform_raw_features(raw_features)
100+
)
91101

92102
return tf_estimator.export.ServingInputReceiver(
93103
transformed_features, serving_input_receiver.receiver_tensors)
@@ -106,8 +116,13 @@ def get_feature_columns(tf_transform_output):
106116
"""
107117
feature_spec = tf_transform_output.transformed_feature_spec()
108118
# Wrap scalars as real valued columns.
119+
def get_shape(spec):
120+
if isinstance(spec, tf.io.SparseFeature):
121+
return spec.size
122+
return spec.shape
123+
109124
return [
110-
tf.feature_column.numeric_column(key, shape=feature_spec[key].shape)
125+
tf.feature_column.numeric_column(key, shape=get_shape(feature_spec[key]))
111126
for key in (common.NUMERIC_FEATURE_KEYS + common.CATEGORICAL_FEATURE_KEYS)
112127
]
113128

examples/census_example_common.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,16 @@ def preprocessing_fn(inputs):
140140
one_hot_encoded = tf.one_hot(
141141
integerized,
142142
depth=tf.cast(depth, tf.int32),
143-
on_value=1.0,
144-
off_value=0.0)
145-
# This output is now one-hot encoded. If saving transformed data to disk,
146-
# this can incur significant memory cost.
147-
outputs[key] = tf.reshape(one_hot_encoded, [-1, depth])
143+
on_value=1,
144+
off_value=0,
145+
dtype=tf.int64)
146+
# Saving one-hot encoded outputs as sparse in order to avoid large dense
147+
# (mostly empty) tensors. This is especially important when saving
148+
# transformed data to disk.
149+
outputs[key] = tf.sparse.from_dense(
150+
tf.reshape(one_hot_encoded, [-1, depth])
151+
)
152+
tft.experimental.annotate_sparse_output_shape(outputs[key], depth)
148153

149154
# For the label column we provide the mapping from string to index.
150155
table_keys = ['>50K', '<=50K']

examples/census_example_v2.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,16 +187,28 @@ def train_and_evaluate(raw_train_eval_data_path_pattern,
187187
feature_spec.pop(common.LABEL_KEY)
188188

189189
inputs = {}
190+
sparse_inputs = {}
191+
dense_inputs = {}
190192
for key, spec in feature_spec.items():
191193
if isinstance(spec, tf.io.FixedLenFeature):
192194
# TODO(b/208879020): Move into schema such that spec.shape is [1] and not
193195
# [] for scalars.
194196
inputs[key] = tf.keras.layers.Input(
195197
shape=spec.shape or [1], name=key, dtype=spec.dtype)
198+
dense_inputs[key] = inputs[key]
199+
elif isinstance(spec, tf.io.SparseFeature):
200+
inputs[key] = tf.keras.layers.Input(
201+
shape=spec.size, name=key, dtype=spec.dtype, sparse=True
202+
)
203+
sparse_inputs[key] = inputs[key]
196204
else:
197205
raise ValueError('Spec type is not supported: ', key, spec)
198206

199-
stacked_inputs = tf.concat(tf.nest.flatten(inputs), axis=1)
207+
outputs = [
208+
tf.keras.layers.Dense(10, activation='relu')(x)
209+
for x in tf.nest.flatten(sparse_inputs)
210+
]
211+
stacked_inputs = tf.concat(tf.nest.flatten(dense_inputs) + outputs, axis=1)
200212
output = tf.keras.layers.Dense(100, activation='relu')(stacked_inputs)
201213
output = tf.keras.layers.Dense(70, activation='relu')(output)
202214
output = tf.keras.layers.Dense(50, activation='relu')(output)

tensorflow_transform/beam/annotators_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def preprocessing_fn(inputs):
3636
outputs = inputs.copy()
3737
x = tf.sparse.expand_dims(inputs['x'], -1)
3838
outputs['x'] = x
39-
tft.experimental.annotate_sparse_output_shape(x, [1, 1])
39+
tft.experimental.annotate_sparse_output_shape(x, tf.constant([1, 1]))
4040
tft.experimental.annotate_sparse_output_shape(outputs['y'], [17])
4141
tft.experimental.annotate_true_sparse_output(outputs['z'])
4242
return outputs

tensorflow_transform/experimental/annotators.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Experimental APIs to get annotations."""
1515

16-
from typing import Sequence
16+
from typing import Sequence, Union
1717

1818
import tensorflow as tf
1919
from tensorflow_transform import annotators
@@ -87,26 +87,37 @@ def get_vocabulary_size_by_name(vocab_filename: str) -> tf.Tensor:
8787
return result
8888

8989

90-
def annotate_sparse_output_shape(tensor: tf.SparseTensor, shape: Sequence[int]):
90+
def annotate_sparse_output_shape(
91+
tensor: tf.SparseTensor, shape: Union[Sequence[int], tf.Tensor]):
9192
"""Annotates a sparse output to have a given dense_shape.
9293
9394
Args:
9495
tensor: An `SparseTensor` to be annotated.
9596
shape: A dense_shape to annotate `tensor` with. Note that this shape does
9697
not include batch_size.
9798
"""
98-
if len(shape) != tensor.shape.rank - 1:
99+
if not isinstance(shape, tf.Tensor):
100+
if (tensor.shape.rank > 1 and tensor.shape.rank - 1 != len(shape)) or (
101+
tensor.shape.rank == 1 and len(shape) != 1):
102+
raise ValueError(
103+
f'Annotated shape {shape} was expected to have rank'
104+
f' {tensor.shape.rank - 1}')
105+
if not all(a is None or a <= b for a, b in zip(tensor.shape[1:], shape)):
106+
raise ValueError(
107+
f'Shape {shape} cannot contain annotated tensor {tensor}')
108+
shape = tf.convert_to_tensor(shape, dtype=tf.int64)
109+
elif shape.shape.rank > 1 or (
110+
shape.shape.rank == 1 and shape.shape[0] != tensor.shape.rank - 1):
99111
raise ValueError(
100-
f'Annotated shape {shape} was expected to have rank'
101-
f' {tensor.shape.rank - 1}'
102-
)
103-
if not all(a is None or a <= b for a, b in zip(tensor.shape[1:], shape)):
104-
raise ValueError(f'Shape {shape} cannot contain annotated tensor {tensor}')
112+
f'Annotation shape has rank {shape.shape.rank} but expected to have'
113+
f' rank {tensor.shape.rank - 1}')
114+
if shape.shape.rank < 1:
115+
shape = tf.expand_dims(shape, -1)
105116
# There's currently no way to override SparseTensor.dense_shape directly,
106117
# unless composing and returning a new SparseTensor.
107-
tensor._dense_shape = tf.convert_to_tensor( # pylint: disable=protected-access
108-
[tensor.dense_shape[0]] + list(shape), dtype=tf.int64
109-
)
118+
tensor._dense_shape = tf.concat( # pylint: disable=protected-access
119+
[tf.expand_dims(tensor.dense_shape[0], -1), tf.cast(shape, tf.int64)],
120+
axis=0)
110121
schema_inference.annotate_sparse_output_shape(tensor, shape)
111122

112123

tensorflow_transform/schema_inference.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
"""
2121

2222
import collections
23-
import functools
2423
import itertools
25-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
24+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
2625
from absl import logging
2726

2827
import tensorflow as tf
@@ -188,7 +187,7 @@ def infer_feature_schema(
188187
tensor_ranges = session.run(tensor_ranges)
189188
tensor_annotations, global_annotations = _get_schema_annotations(
190189
graph, session)
191-
sparse_output_annotations = _get_sparse_output_annotations_v1(graph)
190+
sparse_output_annotations = _get_sparse_output_annotations_v1(graph, session)
192191
modified_sparse_output_annotations = {}
193192
modified_tensor_ranges = {}
194193
feature_annotations = {}
@@ -790,24 +789,14 @@ def metadata_fn():
790789
return module.metadata_fn
791790

792791
_ANNOTATED_SPARSE_SHAPE_TENSORS = 'annotated_sparse_shape_tensors'
793-
_ANNOTATED_SPARSE_SHAPE_RANKS = 'annotated_sparse_shape_ranks'
794-
_ANNOTATED_SPARSE_SHAPE_DIMENSIONS = 'annotated_sparse_shape_dimensions'
792+
_ANNOTATED_SPARSE_SHAPES = 'annotated_sparse_shape_dimensions'
795793
_ANNOTATED_TRUELY_SPARSE_TENSORS = 'annotated_truely_sparse_tensors'
796794

797795

798-
def annotate_sparse_output_shape(tensor: tf.SparseTensor, shape: Sequence[int]):
796+
def annotate_sparse_output_shape(tensor: tf.SparseTensor, shape: tf.Tensor):
799797
"""Annotates a sparse output with a given shape."""
800-
if tensor.shape.rank - 1 != len(shape):
801-
raise ValueError(
802-
f'Output {tensor} was annotated with an incompatible shape: {shape}'
803-
)
804798
tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS, tensor.values)
805-
# We store rank and dimensions to separate collections since TF collections
806-
# don't allow storing lists. This can be simplified if we switch away from
807-
# collections.
808-
tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPE_RANKS, len(shape))
809-
for dim in shape:
810-
tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPE_DIMENSIONS, dim)
799+
tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPES, shape)
811800

812801

813802
def annotate_true_sparse_output(tensor: tf.SparseTensor):
@@ -818,26 +807,16 @@ def annotate_true_sparse_output(tensor: tf.SparseTensor):
818807

819808

820809
def _extract_true_sparse_annotations(
821-
graph: tf.compat.v1.Graph,
822-
) -> List[tf.Tensor]:
810+
graph: tf.compat.v1.Graph) -> List[tf.Tensor]:
823811
"""Extracts true sparse annotations from the graph."""
824812
return graph.get_collection(_ANNOTATED_TRUELY_SPARSE_TENSORS)
825813

826814

827815
def _extract_sparse_output_annotations(
828-
graph: tf.compat.v1.Graph,
829-
) -> List[Tuple[tf.Tensor, List[tf.Tensor]]]:
816+
graph: tf.compat.v1.Graph) -> List[Tuple[tf.Tensor, List[tf.Tensor]]]:
830817
"""Extracts sparse output annotations from the graph."""
831818
tensors = graph.get_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS)
832-
ranks = graph.get_collection(_ANNOTATED_SPARSE_SHAPE_RANKS)
833-
assert len(tensors) == len(ranks), f'{tensors} != {ranks}'
834-
shape_flattened = graph.get_collection(_ANNOTATED_SPARSE_SHAPE_DIMENSIONS)
835-
836-
# Splitting dimensions per annotated tensor.
837-
splits = functools.reduce(lambda lst, x: lst + [lst[-1] + x], ranks, [0])
838-
# Composing the annotated shape per tensor.
839-
shapes = tuple(shape_flattened[s:e] for s, e in zip(splits[:-1], splits[1:]))
840-
819+
shapes = graph.get_collection(_ANNOTATED_SPARSE_SHAPES)
841820
assert len(tensors) == len(shapes), f'{tensors} != {shapes}'
842821
return list(zip(tensors, shapes))
843822

@@ -852,7 +831,7 @@ def _get_sparse_output_annotations(
852831
return list(
853832
itertools.chain(
854833
(
855-
(a, [''])
834+
(a, tf.constant(['']))
856835
for a in _extract_true_sparse_annotations(graph)
857836
if a.ref() not in annotated_refs
858837
),
@@ -862,12 +841,15 @@ def _get_sparse_output_annotations(
862841

863842

864843
def _get_sparse_output_annotations_v1(
865-
graph: tf.compat.v1.Graph,
844+
graph: tf.compat.v1.Graph, session: Optional[tf.compat.v1.Session]
866845
) -> Dict[Any, List[Union[str, tf.Tensor]]]:
867-
return {
868-
tf_utils.hashable_tensor_or_op(kv[0]): kv[1]
869-
for kv in _get_sparse_output_annotations(graph)
870-
}
846+
if not session:
847+
return {}
848+
else:
849+
return {
850+
tf_utils.hashable_tensor_or_op(kv[0]): session.run(kv[1])
851+
for kv in _get_sparse_output_annotations(graph)
852+
}
871853

872854

873855
def _get_sparse_output_annotations_v2(

0 commit comments

Comments
 (0)