-
Notifications
You must be signed in to change notification settings - Fork 62
Multi-Tensor Input in Servo-Beam #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 37 commits
1d03b5a
1d55301
c39a82d
ace3f73
8287878
fedf60d
56357a0
ee6e928
622fbcf
7562e15
037f3b6
86b94c0
7498a05
44af058
3b64e74
f874ce3
300d8c9
a45716b
c42ce34
4e8651c
89878b3
d7e31fc
1a58bbf
0cdf874
c7e2237
42fab7e
1451037
100411b
353604a
5d40e92
f172fe3
668ed88
daf394e
9bc26b4
25b8631
8eb0797
b2e6689
dc9c513
1a12c5c
2fa6720
8c279ce
ff40846
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Column name under which TFXIO stores the raw serialized record
# (tf.train.Example bytes) in a RecordBatch.
_RECORDBATCH_COLUMN = '__RAW_RECORD__'
# Suffix that Keras may append to input layer names; used to map
# '<name>_input' back to the '<name>' feature.
_KERAS_INPUT_SUFFIX = '_input'


class DataType(object):
  """Enumerates the supported serialized input record types."""
  EXAMPLE = 'EXAMPLE'
  SEQUENCEEXAMPLE = 'SEQUENCEEXAMPLE'
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2019 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""TFX-BSL util""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
# Standard __future__ imports | ||
from __future__ import print_function | ||
|
||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pandas as pd | ||
import base64 | ||
import json | ||
import typing | ||
from typing import Dict, List, Text, Any, Set, Optional | ||
from tfx_bsl.beam.bsl_constants import _RECORDBATCH_COLUMN | ||
from tfx_bsl.beam.bsl_constants import _KERAS_INPUT_SUFFIX | ||
|
||
|
||
def ExtractSerializedExampleFromRecordBatch(
    elements: pa.RecordBatch) -> List[Text]:
  """Extracts serialized examples from the raw-record column of a RecordBatch.

  Scans the columns of `elements` for the raw-record column
  (`_RECORDBATCH_COLUMN`) and returns its flattened contents.

  Args:
    elements: Input RecordBatch, expected to contain a column named
      `_RECORDBATCH_COLUMN` holding serialized examples as binary or string
      values.

  Returns:
    A list of serialized examples (bytes or str), one per row.

  Raises:
    ValueError: If the raw-record column is missing, or if its values are
      neither binary nor string typed.
  """
  serialized_examples = None
  for column_name, column_array in zip(elements.schema.names,
                                       elements.columns):
    if column_name != _RECORDBATCH_COLUMN:
      continue
    # Flatten once and reuse; the original flattened the column twice.
    flattened = column_array.flatten()
    column_type = flattened.type
    if not (pa.types.is_binary(column_type) or
            pa.types.is_string(column_type)):
      # BUG FIX: the original formatted the message with `type(example)`,
      # an undefined name, which raised NameError instead of ValueError.
      raise ValueError(
          'Expected a list of serialized examples in bytes or as a string, '
          'got %s' % column_type)
    serialized_examples = flattened.to_pylist()
    break

  if serialized_examples is None:
    raise ValueError('Raw examples not found.')

  return serialized_examples
|
||
|
||
def RecordToJSON(record_batch: pa.RecordBatch,
                 prepare_instances_serialized) -> List[Dict[Text, Any]]:
  """Converts a RecordBatch into a list of JSON-compatible dicts.

  Takes a RecordBatch containing features from a tf.train.Example and returns
  a list of dicts, one per example, suitable for JSON serialization:
    [{feature1: value1, feature2: [value2_1, value2_2], ...}, ...]
  Single-element feature lists are unwrapped to scalars.

  Args:
    record_batch: Input RecordBatch.
    prepare_instances_serialized: If true, each row is emitted as
      {'b64': <base64 of the raw serialized record>} taken from the
      `_RECORDBATCH_COLUMN` column instead of decoded features.

  Returns:
    A list of dicts, one per example.  (BUG FIX: the original docstring and
    `List[Text]` annotation claimed a JSON string was returned, but the code
    returns `json.loads(...)`, i.e. a list of dicts.)
  """
  # TODO (b/155912552): Handle this for sequence example.

  def _flatten(values: List[Any]) -> Any:
    # Unwrap single-element lists so scalar features serialize as scalars
    # rather than one-element arrays.
    return values[0] if len(values) == 1 else values

  df = record_batch.to_pandas()
  if prepare_instances_serialized:
    return [{'b64': base64.b64encode(value).decode()}
            for value in df[_RECORDBATCH_COLUMN]]

  # Cloud AI Platform requires bytes features to be named with a '_bytes'
  # suffix and encoded as {'b64': ...} objects (per review thread).
  as_binary = df.columns.str.endswith('_bytes')
  df.loc[:, as_binary] = df.loc[:, as_binary].applymap(
      lambda feature: [{'b64': base64.b64encode(value).decode()}
                       for value in feature])

  if _RECORDBATCH_COLUMN in df.columns:
    df = df.drop(labels=_RECORDBATCH_COLUMN, axis=1)
  df = df.applymap(_flatten)
  return json.loads(df.to_json(orient='records'))
|
||
|
||
def _find_input_name_in_features(features: Set[Text],
                                 input_name: Text) -> Optional[Text]:
  """Maps input name to an entry in features. Returns None if not found."""
  # Exact match wins.
  if input_name in features:
    return input_name
  # Keras may append '_input' to input layer names; strip the suffix and
  # retry so '<name>_input' still resolves to the '<name>' feature.
  if input_name.endswith(_KERAS_INPUT_SUFFIX):
    stripped = input_name[:-len(_KERAS_INPUT_SUFFIX)]
    if stripped in features:
      return stripped
  return None
|
||
|
||
def filter_tensors_by_input_names(
    tensors: Dict[Text, Any],
    input_names: List[Text]) -> Optional[Dict[Text, Any]]:
  """Filter tensors by input names.

  In case we don't find the specified input name in the tensors and there
  exists only one input name, we assume we are feeding serialized examples to
  the model and return None.

  Args:
    tensors: Dict of tensors.
    input_names: List of input names.

  Returns:
    Filtered tensors keyed by input name, or None when the model is assumed
    to take serialized examples.

  Raises:
    RuntimeError: When the specified input tensor cannot be found.
  """
  if not input_names:
    return None
  result = {}
  tensor_keys = set(tensors.keys())

  # The case where the model takes serialized examples as input: the single
  # declared input does not map to any feature tensor.
  # BUG FIX: the original omitted `is None` and therefore returned None when
  # the input name WAS found, contradicting the documented contract above.
  if (len(input_names) == 1 and
      _find_input_name_in_features(tensor_keys, input_names[0]) is None):
    return None

  for name in input_names:
    tensor_name = _find_input_name_in_features(tensor_keys, name)
    if tensor_name is None:
      raise RuntimeError(
          'Input tensor not found: {}. Existing keys: {}.'.format(
              name, ','.join(tensors.keys())))
    result[name] = tensors[tensor_name]
  return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExtractSerializedExamplesFromRecordBatch