Multi-Tensor Input in Servo-Beam #10
base: master
Changes from all commits
tfx_bsl/beam/bsl_constants.py
@@ -0,0 +1,5 @@
_RECORDBATCH_COLUMN = '__RAW_RECORD__'


class DataType(object):
  EXAMPLE = 'EXAMPLE'
  SEQUENCEEXAMPLE = 'SEQUENCEEXAMPLE'
tfx_bsl/beam/bsl_util.py
@@ -0,0 +1,127 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFX-BSL util."""

from __future__ import absolute_import
from __future__ import division
# Standard __future__ imports
from __future__ import print_function

import base64
import json
import typing
from typing import Any, Dict, List, Mapping, Optional, Set, Text

import numpy as np
import pandas as pd
import pyarrow as pa

from tfx_bsl.beam.bsl_constants import _RECORDBATCH_COLUMN

_KERAS_INPUT_SUFFIX = '_input'

def ExtractSerializedExamplesFromRecordBatch(
    elements: pa.RecordBatch) -> List[Text]:
  serialized_examples = None
  for column_name, column_array in zip(elements.schema.names, elements.columns):
    if column_name == _RECORDBATCH_COLUMN:
      column_type = column_array.flatten().type
      if not (pa.types.is_binary(column_type) or pa.types.is_string(column_type)):
        raise ValueError(
            'Expected a list of serialized examples in bytes or as a string, '
            'got %s' % column_type)
      serialized_examples = column_array.flatten().to_pylist()
      break

  if not serialized_examples:
    raise ValueError('Raw examples not found.')

  return serialized_examples
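A minimal usage sketch (illustrative only; the RecordBatch contents and the list-of-binary layout of the raw-record column are assumptions, not part of this change):

    # Illustrative: a RecordBatch carrying serialized tf.Examples in the
    # reserved raw-record column next to a parsed feature column.
    import pyarrow as pa
    from tfx_bsl.beam import bsl_util
    from tfx_bsl.beam.bsl_constants import _RECORDBATCH_COLUMN

    batch = pa.RecordBatch.from_arrays(
        [pa.array([[1, 2]], type=pa.list_(pa.int32())),
         pa.array([[b'<serialized tf.Example>']], type=pa.list_(pa.binary()))],
        ['y', _RECORDBATCH_COLUMN])

    # Returns the raw bytes of each example, e.g. [b'<serialized tf.Example>'].
    serialized = bsl_util.ExtractSerializedExamplesFromRecordBatch(batch)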

def RecordToJSON(
    record_batch: pa.RecordBatch,
    prepare_instances_serialized: bool) -> List[Mapping[Text, Any]]:
  """Returns a list of JSON dictionaries translated from `record_batch`.

  The conversion takes a RecordBatch that contains features from a
  tf.train.Example and returns a list of dicts, where each item is the JSON
  representation of one example.

  Args:
    record_batch: input RecordBatch.
    prepare_instances_serialized: if True, each returned instance is the
      serialized example wrapped as {'b64': ...}; otherwise each instance is
      a dict of feature values.

  Returns:
    List of JSON dictionaries, in the format
    [{ feature1: value1, feature2: [value2_1, value2_2], ... }, ...].
  """
  # TODO (b/155912552): Handle this for sequence example.
  df = record_batch.to_pandas()
  if prepare_instances_serialized:
    return [{'b64': base64.b64encode(value).decode()}
            for value in df[_RECORDBATCH_COLUMN]]
  else:
    # Cloud AI Platform expects bytes features to carry the '_bytes' suffix
    # and to be base64-wrapped as {'b64': ...}.
    as_binary = df.columns.str.endswith('_bytes')
Review comment: Why does the name end with "_bytes"?
Reply: User-specified byte columns; it's consistent with the original implementation.
Reply: This is required by Cloud AI Platform to indicate bytes features with the '_bytes' suffix.
    df.loc[:, as_binary] = df.loc[:, as_binary].applymap(
        lambda feature: [{'b64': base64.b64encode(value).decode()}
                         for value in feature])

    if _RECORDBATCH_COLUMN in df.columns:
      df = df.drop(labels=_RECORDBATCH_COLUMN, axis=1)
    df = df.applymap(lambda values: values[0] if len(values) == 1 else values)
    return json.loads(df.to_json(orient='records'))
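For reference, a minimal sketch of what the two modes produce (the column names and values below are illustrative, not from this change):

    # Illustrative: a toy RecordBatch with one user-declared bytes feature.
    import base64
    import pyarrow as pa
    from tfx_bsl.beam import bsl_util

    batch = pa.RecordBatch.from_arrays(
        [pa.array([[b'raw-image']], type=pa.list_(pa.binary())),
         pa.array([[1, 2]], type=pa.list_(pa.int32()))],
        ['img_bytes', 'label'])

    instances = bsl_util.RecordToJSON(batch, prepare_instances_serialized=False)
    # Roughly: [{'img_bytes': {'b64': base64.b64encode(b'raw-image').decode()},
    #            'label': [1, 2]}]
    # With prepare_instances_serialized=True, the RecordBatch must carry the
    # raw-record column, and each instance is {'b64': <base64 of the
    # serialized tf.Example>}.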

# TODO: Reuse these functions in TFMA.
def _find_input_name_in_features(features: Set[Text],
                                 input_name: Text) -> Optional[Text]:
  """Maps input name to an entry in features. Returns None if not found."""
  if input_name in features:
    return input_name
  # Some keras models prepend '_input' to the names of the inputs
  # so try under '<name>_input' as well.
  elif (input_name.endswith(_KERAS_INPUT_SUFFIX) and
        input_name[:-len(_KERAS_INPUT_SUFFIX)] in features):
    return input_name[:-len(_KERAS_INPUT_SUFFIX)]
  return None

def filter_tensors_by_input_names(
    tensors: Dict[Text, Any],
    input_names: List[Text]) -> Optional[Dict[Text, Any]]:
  """Filters tensors by input names.

  In case we don't find the specified input name in the tensors and there
  exists only one input name, we assume we are feeding serialized examples to
  the model and return None.

  Args:
    tensors: Dict of tensors.
    input_names: List of input names.

  Returns:
    Filtered tensors.

  Raises:
    RuntimeError: When the specified input tensor cannot be found.
  """
  if not input_names:
    return None
  result = {}
  tensor_keys = set(tensors.keys())

  # The case where the model takes serialized examples as input: the single
  # input name does not map to any parsed feature tensor.
  if len(input_names) == 1 and not _find_input_name_in_features(
      tensor_keys, input_names[0]):
    return None

  for name in input_names:
    tensor_name = _find_input_name_in_features(tensor_keys, name)
    if tensor_name is None:
      raise RuntimeError(
          'Input tensor not found: {}. Existing keys: {}.'.format(
              name, ','.join(tensors.keys())))
    result[name] = tensors[tensor_name]
  return result
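A small usage sketch of the intended mapping (the tensor and input names below are hypothetical, and the serialized-example fallback follows the behavior described in the docstring):

    from tfx_bsl.beam import bsl_util

    # Hypothetical parsed feature tensors, keyed by raw feature name.
    tensors = {'age': [[29]], 'income': [[1000.0]]}

    # A Keras input layer named 'age_input' is resolved to the 'age' feature
    # via the '_input' suffix rule; the result is keyed by the input name.
    bsl_util.filter_tensors_by_input_names(tensors, ['age_input'])
    # -> {'age_input': [[29]]}

    # A single input name that matches no parsed feature: assume the model
    # takes serialized examples and return None.
    bsl_util.filter_tensors_by_input_names(tensors, ['examples'])
    # -> None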
tfx_bsl/beam/bsl_util_test.py
@@ -0,0 +1,91 @@
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx_bsl.bsl_util."""

from __future__ import absolute_import
from __future__ import division
# Standard __future__ imports
from __future__ import print_function

import base64
import json
import os
try:
  import unittest.mock as mock
except ImportError:
  import mock

import apache_beam as beam
import pyarrow as pa
import tensorflow as tf
from google.protobuf import text_format
from tfx_bsl.beam import bsl_util
from tfx_bsl.beam.bsl_constants import _RECORDBATCH_COLUMN

class TestBslUtil(tf.test.TestCase):

  def test_request_body_with_binary_data(self):
    record_batch_remote = pa.RecordBatch.from_arrays(
        [
            pa.array([["ASa8asdf", "ASa8asdf"]], type=pa.list_(pa.binary())),
            pa.array([["JLK7ljk3"]], type=pa.list_(pa.utf8())),
            pa.array([[1, 2]], type=pa.list_(pa.int32())),
            pa.array([[4.5, 5, 5.5]], type=pa.list_(pa.float32()))
        ],
        ['x_bytes', 'x', 'y', 'z'])

    result = list(bsl_util.RecordToJSON(record_batch_remote, False))
    self.assertEqual([
        {
            'x_bytes': [
                {'b64': 'QVNhOGFzZGY='},
                {'b64': 'QVNhOGFzZGY='}
            ],
            'x': 'JLK7ljk3',
            'y': [1, 2],
            'z': [4.5, 5, 5.5]
        },
    ], result)

  def test_request_serialized_example(self):
    example = text_format.Parse(
        """
        features {
          feature { key: "x_bytes" value { bytes_list { value: ["ASa8asdf"] }}}
          feature { key: "x" value { bytes_list { value: "JLK7ljk3" }}}
          feature { key: "y" value { int64_list { value: [1, 2] }}}
        }
        """, tf.train.Example())

    serialized_example_remote = [example.SerializeToString()]
    record_batch_remote = pa.RecordBatch.from_arrays(
        [
            pa.array([["ASa8asdf"]], type=pa.list_(pa.binary())),
            pa.array([["JLK7ljk3"]], type=pa.list_(pa.utf8())),
            pa.array([[1, 2]], type=pa.list_(pa.int32())),
            pa.array([[4.5, 5, 5.5]], type=pa.list_(pa.float32())),
            serialized_example_remote
        ],
        ['x_bytes', 'x', 'y', 'z', _RECORDBATCH_COLUMN])

    result = list(bsl_util.RecordToJSON(record_batch_remote, True))
    self.assertEqual(result, [{
        'b64': base64.b64encode(example.SerializeToString()).decode()
    }])


if __name__ == '__main__':
  tf.test.main()
Review comment: Should _RECORDBATCH_COLUMN be passed as an argument to the API? If we use a constant here, it would mean users would have to use this same constant when creating the TFXIO.
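One possible shape for the parameterized variant the reviewer is suggesting (a hypothetical sketch, not part of this PR):

    # Hypothetical: accept the raw-record column name as an argument,
    # defaulting to the shared constant for backward compatibility, so callers
    # can match whatever column name they configured on their TFXIO.
    def RecordToJSON(record_batch, prepare_instances_serialized,
                     raw_record_column_name=_RECORDBATCH_COLUMN):
      df = record_batch.to_pandas()
      if prepare_instances_serialized:
        return [{'b64': base64.b64encode(value).decode()}
                for value in df[raw_record_column_name]]
      # ... remainder as in the PR's RecordToJSON, with raw_record_column_name
      # replacing _RECORDBATCH_COLUMN when dropping the raw-record column.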