Skip to content

Commit 935fcf0

Browse files
zoyahavtfx-copybara
authored and committed
Implement output sparse tensor annotations: for setting a dense_shape and force representing 2d sparse tensors as sparse as opposed to varlen.
PiperOrigin-RevId: 508646165
1 parent b1c94ed commit 935fcf0

File tree

4 files changed

+462
-29
lines changed

4 files changed

+462
-29
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
* `RaggedTensor`s can now be automatically inferred for variable length
88
features by setting `represent_variable_length_as_ragged=true` in TFMD
99
schema.
10+
* New experimental APIs added for annotating sparse output tensors:
11+
`tft.experimental.annotate_sparse_output_shape` and
12+
`tft.experimental.annotate_true_sparse_output`.
1013

1114
## Bug Fixes and Other Changes
1215

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# coding=utf-8
2+
#
3+
# Copyright 2023 Google Inc. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""Tests for tft annotators."""
17+
18+
import tensorflow as tf
19+
import tensorflow_transform as tft
20+
from tensorflow_transform.beam import tft_unit
21+
from google.protobuf import text_format
22+
from tensorflow_metadata.proto.v0 import schema_pb2
23+
24+
25+
_TF_VERSION_NAMED_PARAMETERS = [
26+
dict(testcase_name='CompatV1', use_tf_compat_v1=True),
27+
dict(testcase_name='V2', use_tf_compat_v1=False),
28+
]
29+
30+
31+
class AnnotatorsTest(tft_unit.TransformTestCase):
  """Tests for the experimental sparse output annotation APIs."""

  def _expected_metadata(self, schema_text):
    """Builds the expected `DatasetMetadata` from a text-format schema.

    Args:
      schema_text: A text-format `schema_pb2.Schema` proto.

    Returns:
      A `tft.DatasetMetadata` wrapping the parsed schema. When not running in
      the external environment, `generate_legacy_feature_spec` is set to False
      on the schema first.
    """
    schema = text_format.Parse(schema_text, schema_pb2.Schema())
    if not tft_unit.is_external_environment():
      schema.generate_legacy_feature_spec = False
    return tft.DatasetMetadata(schema)

  @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS)
  def test_annotate_sparse_outputs(self, use_tf_compat_v1):
    """Checks the output schema for shape- and true-sparse-annotated outputs.

    `x` (expanded to two non-batch dims) is annotated with shape [1, 1], `y`
    with shape [17], `z` is annotated as a true sparse output, and `t` is left
    without annotations.
    """

    def preprocessing_fn(inputs):
      expanded_x = tf.sparse.expand_dims(inputs['x'], -1)
      outputs = inputs.copy()
      outputs['x'] = expanded_x
      tft.experimental.annotate_sparse_output_shape(expanded_x, [1, 1])
      tft.experimental.annotate_sparse_output_shape(outputs['y'], [17])
      tft.experimental.annotate_true_sparse_output(outputs['z'])
      return outputs

    input_data_dicts = [
        {'x': [1], 'y': [2], 'z': [3], 't': [4]} for _ in range(10)
    ]
    input_metadata = tft.DatasetMetadata.from_feature_spec({
        name: tf.io.VarLenFeature(tf.int64) for name in ('x', 'y', 'z', 't')
    })
    # Expected: index domains bounded by the annotated shapes for `x` and `y`,
    # an unbounded index for `z`, and a plain feature for `t`.
    expected_schema_text = """
          feature {
            name: "t"
            type: INT
          }
          feature {
            name: "x$sparse_indices_0"
            type: INT
            int_domain {
              min: 0
              max: 0
            }
          }
          feature {
            name: "x$sparse_indices_1"
            type: INT
            int_domain {
              min: 0
              max: 0
            }
          }
          feature {
            name: "x$sparse_values"
            type: INT
          }
          feature {
            name: "y$sparse_indices_0"
            type: INT
            int_domain {
              min: 0
              max: 16
            }
          }
          feature {
            name: "y$sparse_values"
            type: INT
          }
          feature {
            name: "z$sparse_indices_0"
            type: INT
          }
          feature {
            name: "z$sparse_values"
            type: INT
          }
          sparse_feature {
            name: "x"
            index_feature {
              name: "x$sparse_indices_0"
            }
            index_feature {
              name: "x$sparse_indices_1"
            }
            is_sorted: true
            value_feature {
              name: "x$sparse_values"
            }
          }
          sparse_feature {
            name: "y"
            index_feature {
              name: "y$sparse_indices_0"
            }
            is_sorted: true
            value_feature {
              name: "y$sparse_values"
            }
          }
          sparse_feature {
            name: "z"
            index_feature {
              name: "z$sparse_indices_0"
            }
            is_sorted: true
            value_feature {
              name: "z$sparse_values"
            }
          }
        """
    expected_metadata = self._expected_metadata(expected_schema_text)
    self.assertAnalyzeAndTransformResults(
        input_data_dicts,
        input_metadata,
        preprocessing_fn,
        expected_metadata=expected_metadata,
        force_tf_compat_v1=use_tf_compat_v1,
        output_record_batches=True,
    )

  @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS)
  def test_conflicting_sparse_outputs_annotations(self, use_tf_compat_v1):
    """Checks the output schema when one tensor is annotated several times.

    `x` is annotated with shape [3], then with shape [17], then as a true
    sparse output. The expected schema bounds the index at 16, i.e. it matches
    the last annotated shape.
    """

    def preprocessing_fn(inputs):
      sparse_x = inputs['x']
      tft.experimental.annotate_sparse_output_shape(sparse_x, [3])
      tft.experimental.annotate_sparse_output_shape(sparse_x, [17])
      tft.experimental.annotate_true_sparse_output(sparse_x)
      return inputs

    input_data_dicts = [{'x': [1]} for _ in range(10)]
    input_metadata = tft.DatasetMetadata.from_feature_spec(
        {
            'x': tf.io.VarLenFeature(tf.int64),
        }
    )
    expected_schema_text = """
          feature {
            name: "x$sparse_indices_0"
            type: INT
            int_domain {
              min: 0
              max: 16
            }
          }
          feature {
            name: "x$sparse_values"
            type: INT
          }
          sparse_feature {
            name: "x"
            index_feature {
              name: "x$sparse_indices_0"
            }
            is_sorted: true
            value_feature {
              name: "x$sparse_values"
            }
          }
        """
    expected_metadata = self._expected_metadata(expected_schema_text)
    self.assertAnalyzeAndTransformResults(
        input_data_dicts,
        input_metadata,
        preprocessing_fn,
        expected_metadata=expected_metadata,
        force_tf_compat_v1=use_tf_compat_v1,
        output_record_batches=True,
    )

  @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS)
  def test_invalid_sparse_outputs_annotations(self, use_tf_compat_v1):
    """Checks that an annotated shape of the wrong rank raises `ValueError`."""

    def preprocessing_fn(inputs):
      # `x` comes from a VarLenFeature; the test expects a rank-1 annotation,
      # so a two-element shape must be rejected.
      tft.experimental.annotate_sparse_output_shape(inputs['x'], [3, 42])
      return inputs

    input_data_dicts = [{'x': [1]} for _ in range(10)]
    input_metadata = tft.DatasetMetadata.from_feature_spec(
        {
            'x': tf.io.VarLenFeature(tf.int64),
        }
    )
    expected_error = r'Annotated shape \[3, 42\] was expected to have rank 1'
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, expected_error
    ):
      self.assertAnalyzeAndTransformResults(
          input_data_dicts,
          input_metadata,
          preprocessing_fn,
          force_tf_compat_v1=use_tf_compat_v1,
      )
218+
219+
220+
# Run the tests through the tft_unit entry point when executed as a script.
if __name__ == '__main__':
  tft_unit.main()

tensorflow_transform/experimental/annotators.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,22 @@
1313
# limitations under the License.
1414
"""Experimental APIs to get annotations."""
1515

16+
from typing import Sequence
17+
1618
import tensorflow as tf
1719
from tensorflow_transform import annotators
20+
from tensorflow_transform import schema_inference
1821

1922
from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
2023

2124

25+
__all__ = [
26+
'get_vocabulary_size_by_name',
27+
'annotate_sparse_output_shape',
28+
'annotate_true_sparse_output',
29+
]
30+
31+
2232
def get_vocabulary_size_by_name(vocab_filename: str) -> tf.Tensor:
2333
# pyformat: disable
2434
"""Gets the size of a vocabulary created using `tft.vocabulary`.
@@ -75,3 +85,31 @@ def get_vocabulary_size_by_name(vocab_filename: str) -> tf.Tensor:
7585
'`vocab_filename` argument passed to it.')
7686

7787
return result
88+
89+
90+
def annotate_sparse_output_shape(tensor: tf.SparseTensor, shape: Sequence[int]):
  """Annotates a sparse output to have a given dense_shape.

  Besides registering the annotation with `schema_inference`, this overwrites
  the non-batch entries of `tensor`'s dense_shape in place with `shape`.

  Args:
    tensor: A `SparseTensor` to be annotated.
    shape: A dense_shape to annotate `tensor` with. Note that this shape does
      not include batch_size.

  Raises:
    ValueError: If `len(shape)` differs from `tensor.shape.rank - 1`, or if a
      statically known non-batch dimension of `tensor` is larger than the
      corresponding entry of `shape`.
  """
  expected_rank = tensor.shape.rank - 1
  if len(shape) != expected_rank:
    raise ValueError(
        f'Annotated shape {shape} was expected to have rank {expected_rank}'
    )
  # Every statically known non-batch dimension must fit inside the annotated
  # shape; unknown (None) dimensions are accepted as-is.
  for static_dim, annotated_dim in zip(tensor.shape[1:], shape):
    if static_dim is None or static_dim <= annotated_dim:
      continue
    raise ValueError(f'Shape {shape} cannot contain annotated tensor {tensor}')
  # There's currently no way to override SparseTensor.dense_shape directly,
  # unless composing and returning a new SparseTensor, so the private attribute
  # is replaced while keeping the original batch dimension.
  batch_dim = tensor.dense_shape[0]
  tensor._dense_shape = tf.convert_to_tensor(  # pylint: disable=protected-access
      [batch_dim, *shape], dtype=tf.int64
  )
  schema_inference.annotate_sparse_output_shape(tensor, shape)
111+
112+
113+
def annotate_true_sparse_output(tensor: tf.SparseTensor) -> None:
  """Annotates a sparse output to be truly sparse and not varlen.

  Delegates to `schema_inference.annotate_true_sparse_output`.

  Args:
    tensor: A `SparseTensor` to be annotated.
  """
  schema_inference.annotate_true_sparse_output(tensor)

0 commit comments

Comments
 (0)