Skip to content

Commit 3d31835

Browse files
tf-transform-team authored and tfx-copybara committed
Added element-wise scaling support to scale_by_min_max_per_key and scale_to_0_1_per_key for key_vocabulary_filename = None
PiperOrigin-RevId: 456846710
1 parent 04064f3 commit 3d31835

File tree

3 files changed

+121
-58
lines changed

3 files changed

+121
-58
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
## Major Features and Improvements
66

7+
* Adds element-wise scaling support to `scale_by_min_max_per_key` and
8+
`scale_to_0_1_per_key` for `key_vocabulary_filename = None`.
9+
710
## Bug Fixes and Other Changes
811

912
* Depends on `tensorflow>=1.15.5,<2` or `tensorflow>=2.9,<2.10`

tensorflow_transform/beam/impl_test.py

Lines changed: 109 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -788,13 +788,14 @@ def preprocessing_fn(inputs):
788788
preprocessing_fn, expected_data,
789789
expected_metadata)
790790

791-
def testScaleUnitIntervalPerKey(self):
791+
@tft_unit.parameters((True,), (False,))
792+
def testScaleUnitIntervalPerKey(self, elementwise):
792793

793794
def preprocessing_fn(inputs):
794795
outputs = {}
795796
stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
796797
result = tft.scale_to_0_1_per_key(
797-
stacked_input, inputs['key'], elementwise=False)
798+
stacked_input, inputs['key'], elementwise)
798799
outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1)
799800
return outputs
800801

@@ -828,25 +829,46 @@ def preprocessing_fn(inputs):
828829
'y': tf.io.FixedLenFeature([], tf.float32),
829830
'key': tf.io.FixedLenFeature([], tf.string)
830831
})
831-
expected_data = [{
832-
'x_scaled': 0.6,
833-
'y_scaled': 0.8
834-
}, {
835-
'x_scaled': 0.0,
836-
'y_scaled': 0.2
837-
}, {
838-
'x_scaled': 0.8,
839-
'y_scaled': 1.0
840-
}, {
841-
'x_scaled': 0.2,
842-
'y_scaled': 0.4
843-
}, {
844-
'x_scaled': 1.0,
845-
'y_scaled': 0.0
846-
}, {
847-
'x_scaled': 0.6,
848-
'y_scaled': 0.5
849-
}]
832+
if elementwise:
833+
expected_data = [{
834+
'x_scaled': 0.75,
835+
'y_scaled': 0.75
836+
}, {
837+
'x_scaled': 0.0,
838+
'y_scaled': 0.0
839+
}, {
840+
'x_scaled': 1.0,
841+
'y_scaled': 1.0
842+
}, {
843+
'x_scaled': 0.25,
844+
'y_scaled': 0.25
845+
}, {
846+
'x_scaled': 1.0,
847+
'y_scaled': 0.0
848+
}, {
849+
'x_scaled': 0.0,
850+
'y_scaled': 1.0
851+
}]
852+
else:
853+
expected_data = [{
854+
'x_scaled': 0.6,
855+
'y_scaled': 0.8
856+
}, {
857+
'x_scaled': 0.0,
858+
'y_scaled': 0.2
859+
}, {
860+
'x_scaled': 0.8,
861+
'y_scaled': 1.0
862+
}, {
863+
'x_scaled': 0.2,
864+
'y_scaled': 0.4
865+
}, {
866+
'x_scaled': 1.0,
867+
'y_scaled': 0.0
868+
}, {
869+
'x_scaled': 0.6,
870+
'y_scaled': 0.5
871+
}]
850872
expected_metadata = tft.DatasetMetadata.from_feature_spec({
851873
'x_scaled': tf.io.FixedLenFeature([], tf.float32),
852874
'y_scaled': tf.io.FixedLenFeature([], tf.float32)
@@ -919,14 +941,24 @@ def preprocessing_fn(inputs):
919941
expected_metadata)
920942

921943
@tft_unit.named_parameters(
922-
dict(testcase_name='_empty_filename',
923-
key_vocabulary_filename=''),
924-
dict(testcase_name='_nonempty_filename',
925-
key_vocabulary_filename='per_key'),
926-
dict(testcase_name='_none_filename',
927-
key_vocabulary_filename=None)
928-
)
929-
def testScaleMinMaxPerKey(self, key_vocabulary_filename):
944+
dict(
945+
testcase_name='_empty_filename',
946+
elementwise=False,
947+
key_vocabulary_filename=''),
948+
dict(
949+
testcase_name='_nonempty_filename',
950+
elementwise=False,
951+
key_vocabulary_filename='per_key'),
952+
dict(
953+
testcase_name='_none_filename',
954+
elementwise=False,
955+
key_vocabulary_filename=None),
956+
dict(
957+
testcase_name='_elementwise_none_filename',
958+
elementwise=True,
959+
key_vocabulary_filename=None))
960+
def testScaleMinMaxPerKey(self, elementwise, key_vocabulary_filename):
961+
930962
def preprocessing_fn(inputs):
931963
outputs = {}
932964
stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1)
@@ -935,7 +967,7 @@ def preprocessing_fn(inputs):
935967
inputs['key'],
936968
output_min=-1,
937969
output_max=1,
938-
elementwise=False,
970+
elementwise=elementwise,
939971
key_vocabulary_filename=key_vocabulary_filename)
940972
outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1)
941973
return outputs
@@ -970,37 +1002,61 @@ def preprocessing_fn(inputs):
9701002
'y': tf.io.FixedLenFeature([], tf.float32),
9711003
'key': tf.io.FixedLenFeature([], tf.string)
9721004
})
973-
974-
expected_data = [{
975-
'x_scaled': -0.25,
976-
'y_scaled': 0.75
977-
}, {
978-
'x_scaled': -1.0,
979-
'y_scaled': 0.0
980-
}, {
981-
'x_scaled': 0.0,
982-
'y_scaled': 1.0
983-
}, {
984-
'x_scaled': -0.75,
985-
'y_scaled': 0.25
986-
}, {
987-
'x_scaled': -1.0,
988-
'y_scaled': 0.0
989-
}, {
990-
'x_scaled': 0.0,
991-
'y_scaled': 1.0
992-
}]
1005+
if elementwise:
1006+
expected_data = [{
1007+
'x_scaled': 0.5,
1008+
'y_scaled': 0.5
1009+
}, {
1010+
'x_scaled': -1.0,
1011+
'y_scaled': -1.0
1012+
}, {
1013+
'x_scaled': 1.0,
1014+
'y_scaled': 1.0
1015+
}, {
1016+
'x_scaled': -0.5,
1017+
'y_scaled': -0.5
1018+
}, {
1019+
'x_scaled': -1.0,
1020+
'y_scaled': -1.0
1021+
}, {
1022+
'x_scaled': 1.0,
1023+
'y_scaled': 1.0
1024+
}]
1025+
else:
1026+
expected_data = [{
1027+
'x_scaled': -0.25,
1028+
'y_scaled': 0.75
1029+
}, {
1030+
'x_scaled': -1.0,
1031+
'y_scaled': 0.0
1032+
}, {
1033+
'x_scaled': 0.0,
1034+
'y_scaled': 1.0
1035+
}, {
1036+
'x_scaled': -0.75,
1037+
'y_scaled': 0.25
1038+
}, {
1039+
'x_scaled': -1.0,
1040+
'y_scaled': 0.0
1041+
}, {
1042+
'x_scaled': 0.0,
1043+
'y_scaled': 1.0
1044+
}]
9931045
expected_metadata = tft.DatasetMetadata.from_feature_spec({
9941046
'x_scaled': tf.io.FixedLenFeature([], tf.float32),
9951047
'y_scaled': tf.io.FixedLenFeature([], tf.float32)
9961048
})
9971049
if key_vocabulary_filename:
998-
per_key_vocab_contents = {key_vocabulary_filename:
999-
[(b'a', [-1.0, 9.0]), (b'b', [2.0, 2.0])]}
1050+
per_key_vocab_contents = {
1051+
key_vocabulary_filename: [(b'a', [-1.0, 9.0]), (b'b', [2.0, 2.0])]
1052+
}
10001053
else:
10011054
per_key_vocab_contents = None
10021055
self.assertAnalyzeAndTransformResults(
1003-
input_data, input_metadata, preprocessing_fn, expected_data,
1056+
input_data,
1057+
input_metadata,
1058+
preprocessing_fn,
1059+
expected_data,
10041060
expected_metadata,
10051061
expected_vocab_file_contents=per_key_vocab_contents)
10061062

tensorflow_transform/mappers.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,13 @@ def _scale_by_min_max_internal(
367367
x,
368368
reduce_instance_dims=not elementwise)
369369
else:
370-
if elementwise:
371-
raise NotImplementedError('Per-key elementwise reduction not supported')
370+
if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
371+
raise NotImplementedError(
372+
'Per-key elementwise reduction of Composite Tensors not supported')
372373
key_values = analyzers._min_and_max_per_key( # pylint: disable=protected-access
373374
x,
374375
key,
375-
reduce_instance_dims=True,
376+
reduce_instance_dims=not elementwise,
376377
key_vocabulary_filename=key_vocabulary_filename)
377378
if key_vocabulary_filename is None:
378379
key_vocab, min_x_value, max_x_value = key_values
@@ -381,10 +382,13 @@ def _scale_by_min_max_internal(
381382
min_x_value, max_x_value = tf_utils.map_per_key_reductions(
382383
(min_x_value, max_x_value), key, key_vocab, x, not elementwise)
383384
else:
385+
if elementwise:
386+
raise NotImplementedError(
387+
'Elementwise scaling does not support key_vocabulary_filename')
384388
minus_min_max_for_key = tf_utils.apply_per_key_vocabulary(
385389
key_values, key, target_ndims=x.get_shape().ndims)
386-
min_x_value, max_x_value = (
387-
-minus_min_max_for_key[:, 0], minus_min_max_for_key[:, 1])
390+
min_x_value, max_x_value = (-minus_min_max_for_key[:, 0],
391+
minus_min_max_for_key[:, 1])
388392

389393
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
390394
x_values = tf_utils.get_values(x)

0 commit comments

Comments (0)