Skip to content

Commit 55d1c8b

Browse files
tf-transform-team authored and tfx-copybara committed
Added element-wise scaling support to scale_to_z_score_per_key for key_vocabulary_filename = None
PiperOrigin-RevId: 457085189
1 parent 3d31835 commit 55d1c8b

File tree

4 files changed

+193
-9
lines changed

4 files changed

+193
-9
lines changed

RELEASE.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
## Major Features and Improvements
66

7-
* Adds element-wise scaling support to `scale_by_min_max_per_key` and
8-
`scale_to_0_1_per_key` for `key_vocabulary_filename = None`.
9-
7+
* Adds element-wise scaling support to `scale_by_min_max_per_key`,
8+
`scale_to_0_1_per_key` and `scale_to_z_score_per_key` for
9+
`key_vocabulary_filename = None`.
1010
## Bug Fixes and Other Changes
1111

1212
* Depends on `tensorflow>=1.15.5,<2` or `tensorflow>=2.9,<2.10`

tensorflow_transform/analyzers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,16 +1106,18 @@ def _mean_and_var_per_key(
11061106
if key is None:
11071107
raise ValueError('A non-None key is required for _mean_and_var_per_key')
11081108

1109-
if not reduce_instance_dims:
1110-
raise NotImplementedError('Per-key elementwise reduction not supported')
1109+
if not reduce_instance_dims and isinstance(
1110+
x, (tf.SparseTensor, tf.RaggedTensor)):
1111+
raise NotImplementedError(
1112+
'Per-key elementwise reduction of Composite Tensors not supported ')
11111113

11121114
with tf.compat.v1.name_scope('mean_and_var_per_key'):
11131115
x = tf.cast(x, output_dtype)
11141116

11151117
key_vocab, key_counts, key_means, key_variances = (
11161118
tf_utils.reduce_batch_count_mean_and_var_per_key(
11171119
x, key, reduce_instance_dims=reduce_instance_dims))
1118-
output_shape = ()
1120+
output_shape = () if reduce_instance_dims else x.get_shape()[1:]
11191121

11201122
combine_inputs = _WeightedMeanAndVarAccumulator(
11211123
count=key_counts,

tensorflow_transform/beam/impl_test.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,6 +1809,139 @@ def scale_to_z_score_per_key(tensor, key, var_name=''):
18091809
preprocessing_fn, expected_data,
18101810
expected_metadata)
18111811

1812+
@tft_unit.named_parameters(
1813+
dict(
1814+
testcase_name='_float',
1815+
input_data=[
1816+
{
1817+
'x': [-4, 0],
1818+
'key': 'a',
1819+
},
1820+
{
1821+
'x': [10, 0],
1822+
'key': 'a',
1823+
},
1824+
{
1825+
'x': [2, 0],
1826+
'key': 'a',
1827+
},
1828+
{
1829+
'x': [4, 0],
1830+
'key': 'a',
1831+
},
1832+
{
1833+
'x': [1, 0],
1834+
'key': 'b',
1835+
},
1836+
{
1837+
'x': [-1, 0],
1838+
'key': 'b',
1839+
},
1840+
{
1841+
'x': [np.nan, np.nan],
1842+
'key': 'b',
1843+
},
1844+
],
1845+
# Elementwise = True
1846+
# Mean [a, b] = [[ 3.0, 0.0], [0.0, 0.0]]
1847+
# Variance [a, b] = [[25.0, 0.0], [1.0, 0.0]]
1848+
# StdDev [a, b] = [[ 5.0, 0.0], [1.0, 0.0]]
1849+
expected_data=[
1850+
{
1851+
'x_scaled': [-1.4, 0.0], # [(-4 - 3) / 5, (0 - 0) / 0]
1852+
},
1853+
{
1854+
'x_scaled': [1.4, 0.0] # [(10 - 3) / 5, (0 - 0) / 0]
1855+
},
1856+
{
1857+
'x_scaled': [-0.2, 0.0] # [(2 - 3) / 5, (0 - 0) / 0]
1858+
},
1859+
{
1860+
'x_scaled': [0.2, 0.0], # [(4 - 3) / 5, (0 - 0) / 0]
1861+
},
1862+
{
1863+
'x_scaled': [1.0, 0.0] # [(1 - 0) / 1, (0 - 0) / 0]
1864+
},
1865+
{
1866+
'x_scaled': [-1.0, 0.0] # [(-1 - 0) / 1, (0 - 0) / 0]
1867+
},
1868+
{
1869+
'x_scaled': [np.nan, np.nan]
1870+
},
1871+
],
1872+
input_metadata=tft.DatasetMetadata.from_feature_spec({
1873+
'x': tf.io.FixedLenFeature([2], tf.float32),
1874+
'key': tf.io.FixedLenFeature([], tf.string),
1875+
}),
1876+
expected_metadata=tft.DatasetMetadata.from_feature_spec({
1877+
'x_scaled': tf.io.FixedLenFeature([2], tf.float32),
1878+
})),
1879+
dict(
1880+
testcase_name='float_3dims',
1881+
input_data=[
1882+
{
1883+
'x': [[-4, -8], [-12, -16]],
1884+
'key': 'a',
1885+
},
1886+
{
1887+
'x': [[10, 20], [30, 40]],
1888+
'key': 'a',
1889+
},
1890+
{
1891+
'x': [[2, 4], [6, 8]],
1892+
'key': 'a',
1893+
},
1894+
{
1895+
'x': [[4, 8], [12, 16]],
1896+
'key': 'a',
1897+
},
1898+
{
1899+
'x': [[1, 2], [3, 4]],
1900+
'key': 'b',
1901+
},
1902+
],
1903+
expected_data=[
1904+
{
1905+
'x_scaled': [[-1.4, -1.4], [-1.4, -1.4]],
1906+
},
1907+
{
1908+
'x_scaled': [[1.4, 1.4], [1.4, 1.4]],
1909+
},
1910+
{
1911+
'x_scaled': [[-0.2, -0.2], [-0.2, -0.2]],
1912+
},
1913+
{
1914+
'x_scaled': [[0.2, 0.2], [0.2, 0.2]],
1915+
},
1916+
{
1917+
'x_scaled': [[0.0, 0.0], [0.0, 0.0]],
1918+
},
1919+
],
1920+
input_metadata=tft.DatasetMetadata.from_feature_spec({
1921+
'x': tf.io.FixedLenFeature([2, 2], tf.float32),
1922+
'key': tf.io.FixedLenFeature([], tf.string),
1923+
}),
1924+
expected_metadata=tft.DatasetMetadata.from_feature_spec({
1925+
'x_scaled': tf.io.FixedLenFeature([2, 2], tf.float32),
1926+
})),
1927+
)
1928+
def testScaleToZScorePerKeyElementwise(self, input_data, expected_data,
1929+
input_metadata, expected_metadata):
1930+
1931+
def preprocessing_fn(inputs):
1932+
outputs = {}
1933+
outputs['x_scaled'] = tft.scale_to_z_score_per_key(
1934+
tf.cast(inputs['x'], tf.float32),
1935+
key=inputs['key'],
1936+
elementwise=True,
1937+
key_vocabulary_filename=None)
1938+
self.assertEqual(outputs['x_scaled'].dtype, tf.float32)
1939+
return outputs
1940+
1941+
self.assertAnalyzeAndTransformResults(input_data, input_metadata,
1942+
preprocessing_fn, expected_data,
1943+
expected_metadata)
1944+
18121945
@tft_unit.parameters(
18131946
(tf.int16,),
18141947
(tf.int32,),
@@ -1975,6 +2108,48 @@ def analyzer_fn(inputs):
19752108
expected_outputs,
19762109
desired_batch_size=10)
19772110

2111+
def testMeanAndVarPerKeyElementwise(self):
2112+
2113+
def analyzer_fn(inputs):
2114+
key_vocab, mean, var = analyzers._mean_and_var_per_key(
2115+
inputs['x'], inputs['key'], reduce_instance_dims=False)
2116+
return {
2117+
'key_vocab': key_vocab,
2118+
'mean': mean,
2119+
'var': tf.round(100 * var) / 100.0
2120+
}
2121+
2122+
input_data = input_data = [{
2123+
'x': [-4, -1],
2124+
'key': 'a',
2125+
}, {
2126+
'x': [10, 0],
2127+
'key': 'a',
2128+
}, {
2129+
'x': [2, 0],
2130+
'key': 'a',
2131+
}, {
2132+
'x': [4, -1],
2133+
'key': 'a',
2134+
}, {
2135+
'x': [10, 0],
2136+
'key': 'b',
2137+
}, {
2138+
'x': [0, 10],
2139+
'key': 'b',
2140+
}]
2141+
input_metadata = tft.DatasetMetadata.from_feature_spec({
2142+
'x': tf.io.FixedLenFeature([2], tf.float32),
2143+
'key': tf.io.FixedLenFeature([], tf.string)
2144+
})
2145+
expected_outputs = {
2146+
'key_vocab': np.array([b'a', b'b'], np.object),
2147+
'mean': np.array([[3.0, -0.5], [5.0, 5.0]], np.float32),
2148+
'var': np.array([[25.0, 0.25], [25.0, 25.0]], np.float32)
2149+
}
2150+
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
2151+
expected_outputs)
2152+
19782153
@tft_unit.named_parameters(
19792154
dict(
19802155
testcase_name='_dense_2d',

tensorflow_transform/mappers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,15 @@ def _scale_to_z_score_internal(
619619
reduce_instance_dims=not elementwise,
620620
output_dtype=output_dtype)
621621
else:
622-
if elementwise:
623-
raise NotImplementedError('Per-key elementwise reduction not supported')
622+
if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
623+
raise NotImplementedError(
624+
'Per-key elementwise reduction of Composite Tensors not supported')
624625

625626
mean_and_var_per_key_result = analyzers._mean_and_var_per_key( # pylint: disable=protected-access
626-
x, key, key_vocabulary_filename=key_vocabulary_filename,
627+
x,
628+
key,
629+
reduce_instance_dims=not elementwise,
630+
key_vocabulary_filename=key_vocabulary_filename,
627631
output_dtype=output_dtype)
628632

629633
if key_vocabulary_filename is None:
@@ -633,6 +637,9 @@ def _scale_to_z_score_internal(
633637
x_mean, x_var = tf_utils.map_per_key_reductions(
634638
(key_means, key_vars), key, key_vocab, x, not elementwise)
635639
else:
640+
if elementwise:
641+
raise NotImplementedError(
642+
'Elementwise scaling does not support key_vocabulary_filename')
636643
mean_var_for_key = tf_utils.apply_per_key_vocabulary(
637644
mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims)
638645
x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1])

0 commit comments

Comments (0)