Changes from 7 commits

Commits (33)
696a366
add interactive cli for python converter
pyu10055 Apr 17, 2019
96fe586
Merge branch 'master' into interactive-cli
pyu10055 Jun 24, 2019
f0f6989
add auto looking up the saved model tags and signatures
pyu10055 Jun 25, 2019
e944f24
fixed pylint errors and added tests
pyu10055 Jun 28, 2019
7e93a4f
adding more docstrings
pyu10055 Jun 29, 2019
c9a63dc
fix typo
pyu10055 Jul 1, 2019
3d9d5ac
merged master
pyu10055 Jul 1, 2019
82a0648
update the cli workflow according to the design doc comments
pyu10055 Jul 3, 2019
91f781e
fix pylint
pyu10055 Jul 3, 2019
ea5ed07
show the dtype string instead of value
pyu10055 Jul 3, 2019
093c8aa
fix pylint error
pyu10055 Jul 3, 2019
3431a07
move to questionary pip to support prompt_toolkit 2
pyu10055 Jul 17, 2019
1d6382d
Merge branch 'master' into interactive-cli
pyu10055 Jul 29, 2019
0c7d70b
update the README and fixed the tests
pyu10055 Jul 29, 2019
9acd99a
addressed the comments and add dryrun arg to generate the raw convert…
pyu10055 Jul 30, 2019
d24ad9b
fixed bugs
pyu10055 Jul 30, 2019
a4ea88f
address comments
pyu10055 Jul 30, 2019
4512586
use tuple instead of list
pyu10055 Jul 30, 2019
066af81
address the comments
pyu10055 Jul 31, 2019
0cc44e6
update compression choices
pyu10055 Aug 1, 2019
e2135a4
fix pylint error
pyu10055 Aug 1, 2019
4cdcda8
more pylint error
pyu10055 Aug 1, 2019
c245c3a
Merge branch 'master' into interactive-cli
pyu10055 Aug 1, 2019
89d52a1
used listdir to support py2
pyu10055 Aug 2, 2019
4995679
update text
pyu10055 Aug 2, 2019
725ddd2
fix pylint error
pyu10055 Aug 6, 2019
a0ce222
Merge branch 'master' into interactive-cli
pyu10055 Aug 6, 2019
0193b61
addressed comments
pyu10055 Aug 6, 2019
bd76a7f
Merge branch 'master' into interactive-cli
pyu10055 Aug 6, 2019
fe61a3d
Merge branch 'interactive-cli' of github.com:tensorflow/web into inte…
pyu10055 Aug 6, 2019
d4bc566
Merge branch 'master' into interactive-cli
pyu10055 Aug 14, 2019
fc509d2
Merge branch 'master' into interactive-cli
pyu10055 Aug 14, 2019
9c1d0e4
fixed test error
pyu10055 Aug 14, 2019
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -4,3 +4,4 @@ numpy==1.16.4
six==1.11.0
tensorflow==1.14.0
tensorflow-hub==0.5.0
PyInquirer==1.0.3
2 changes: 2 additions & 0 deletions python/setup.py
@@ -27,6 +27,7 @@ def _get_requirements(file):

CONSOLE_SCRIPTS = [
    'tensorflowjs_converter = tensorflowjs.converters.converter:main',
    'tensorflowjs_cli = tensorflowjs.cli:main',
]

setuptools.setup(
@@ -54,6 +55,7 @@ def _get_requirements(file):
    ],
    py_modules=[
        'tensorflowjs',
        'tensorflowjs.cli',
        'tensorflowjs.version',
        'tensorflowjs.quantization',
        'tensorflowjs.read_weights',
22 changes: 22 additions & 0 deletions python/tensorflowjs/BUILD
@@ -99,6 +99,28 @@ py_test(
    ],
)

py_test(
    name = "cli_test",
    srcs = ["cli_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":expect_numpy_installed",
        ":cli",
    ],
)

py_binary(
    name = "cli",
    srcs = ["cli.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters/converter",
        "//tensorflowjs:expect_h5py_installed",
        "//tensorflowjs:expect_keras_installed",
        "//tensorflowjs:expect_tensorflow_installed",
    ],
)

# A filegroup BUILD target that includes all the op list json files in
# the op_list/ folder. The op_list folder itself is a symbolic link to the
# actual op_list folder under src/.
1 change: 1 addition & 0 deletions python/tensorflowjs/__init__.py
@@ -21,5 +21,6 @@
from tensorflowjs import converters
from tensorflowjs import quantization
from tensorflowjs import version
from tensorflowjs import cli

__version__ = version.version
291 changes: 291 additions & 0 deletions python/tensorflowjs/cli.py
@@ -0,0 +1,291 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Interactive command line tool for tensorflow.js model conversion."""

from __future__ import print_function, unicode_literals

import os
import re

from PyInquirer import prompt
from examples import custom_style_3
from tensorflow.python.saved_model.loader_impl import parse_saved_model

# Regex for recognizing a valid URL for a TFHub module.
URL_REGEX = re.compile(
    # http:// or https://
    r'^(?:http)s?://', re.IGNORECASE)


def quantization_type(value):
  """Determine the quantization type based on user selection.

  Args:
    value: user selected value.
  """
  answer = None
  try:
    if '1/2' in value:
      answer = 2
    elif '1/4' in value:
      answer = 1
  except TypeError:  # value is not a string (e.g. None)
    answer = None
  return answer
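
# Illustration (not part of the diff): the compression choices offered in
# main() below map onto quantization byte widths as follows.
assert quantization_type('No compression') is None
assert quantization_type('compress weights to 1/2 the size') == 2
assert quantization_type('compress weights to 1/4 the size') == 1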


def of_values(answers, key, values):
  """Determine whether the user's answer for the key is in the values list.

  Args:
    answers: Dict of user's answers to the questions.
    key: question key.
    values: List of values to check from.
  """
  try:
    value = answers[key]
    return value in values
  except KeyError:
    return False


def input_path_message(answers):
  """Determine the question message for the model's input path.

  Args:
    answers: Dict of user's answers to the questions.
  """
  answer = answers['input_format']
  if answer == 'keras':
    return 'What is the path of the input HDF5 file?'
  elif answer == 'tf_hub':
    return 'What is the TFHub module URL?'
  else:
    return 'What is the directory that contains the model?'


def validate_input_path(value, input_format):
  """Validate the input path for the given input format.

  Args:
    value: input path of the model.
    input_format: model format string.
  """
  value = os.path.expanduser(value)
  if input_format == 'tf_hub':
    if re.match(URL_REGEX, value) is None:
      return 'This is not a valid URL for a TFHub module: %s' % value
  elif not os.path.exists(value):
    return 'Nonexistent path for the model: %s' % value
  if input_format in ['keras_saved_model', 'tf_saved_model']:
    if not os.path.isdir(value):
      return 'The path provided is not a directory: %s' % value
    if not any(fname.endswith('.pb') for fname in os.listdir(value)):
      return 'This is an invalid saved model directory: %s' % value
  if input_format == 'tfjs_layers_model':
    if not os.path.isfile(value):
      return 'The path provided is not a file: %s' % value
  if input_format == 'keras':
    if not os.path.isfile(value):
      return 'The path provided is not a file: %s' % value
  return True
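
# Note (not part of the diff): PyInquirer's `validate` convention is that
# returning True accepts the answer, while returning a string displays that
# string to the user as the error message, so the validators above double as
# inline error reporting.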


def validate_output_path(value):
  """Validate the output path for the converted model.

  Args:
    value: output path for the converted model artifacts.
  """
  value = os.path.expanduser(value)
  if os.path.exists(value):
    return 'The output path already exists: %s' % value
  return True


def generate_command(params):
  """Generate the tensorflowjs command string for the selected params.

  Args:
    params: user selected parameters for the conversion.
  """
  args = 'tensorflowjs_converter'
  not_param_list = ['input_path', 'output_path']
  no_false_param = ['split_weights_by_layer', 'skip_op_check']
  for key, value in sorted(params.items()):
    if key not in not_param_list and value is not None:
      if key in no_false_param:
        if value is True:
          args += ' --%s' % (key)
      else:
        args += ' --%s=%s' % (key, value)

  args += ' %s %s' % (params['input_path'], params['output_path'])
  return args
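
# Illustration (not part of the diff): with a hypothetical answer dict such as
#   {'input_format': 'tf_saved_model', 'quantization_bytes': 2,
#    'input_path': '/tmp/model', 'output_path': '/tmp/web_model'}
# generate_command returns:
#   'tensorflowjs_converter --input_format=tf_saved_model'
#   ' --quantization_bytes=2 /tmp/model /tmp/web_model'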

def is_saved_model(answers):
  """Check if the input path contains a saved model.

  Args:
    answers: Dict of user's answers to the questions.
  """
  return answers['input_format'] == 'tf_saved_model' or (
      answers['input_format'] == 'keras_saved_model' and
      answers['output_format'] == 'tfjs_graph_model')

def available_output_formats(answers):
  """Generate the output formats for the given input format.

  Args:
    answers: user selected parameter dict.
  """
  input_format = answers['input_format']
  if input_format == 'keras_saved_model':
    return ['tfjs_graph_model', 'tfjs_layers_model']
  if input_format == 'tfjs_layers_model':
    return ['keras', 'tfjs_graph_model']
  return []


def available_tags(answers):
  """Generate the available saved model tags from the proto file.

  Args:
    answers: user selected parameter dict.
  """
  if is_saved_model(answers):
    saved_model = parse_saved_model(answers['input_path'])
    tags = []
    for meta_graph in saved_model.meta_graphs:
      tags.append(",".join(meta_graph.meta_info_def.tags))
    return tags
  return []


def available_signature_names(answers):
  """Generate the available saved model signatures from the proto file
  and selected tags.

  Args:
    answers: user selected parameter dict.
  """
  if is_saved_model(answers):
    path = answers['input_path']
    tags = answers['saved_model_tags']
    saved_model = parse_saved_model(path)
    for meta_graph in saved_model.meta_graphs:
      if tags == ",".join(meta_graph.meta_info_def.tags):
        # list() so the keys can be used as PyInquirer list choices.
        return list(meta_graph.signature_def.keys())
  return []
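
# Illustration (not part of the diff): for a typical TF1-style SavedModel the
# helpers above would commonly surface a single 'serve' tag set and a
# 'serving_default' signature, which main() below presents as list choices.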


def main():
  print('Welcome to TensorFlow.js converter.')

  formats = [
      {
          'type': 'list',
          'name': 'input_format',
          'message': 'What is your input format?',
          'choices': ['keras', 'keras_saved_model',
                      'tf_saved_model', 'tf_hub', 'tfjs_layers_model']
      },
      {
          'type': 'list',
          'name': 'output_format',
          'message': 'What is your output format?',
          'choices': available_output_formats,
          'when': lambda answers: of_values(answers, 'input_format',
                                            ['keras_saved_model',
                                             'tfjs_layers_model'])
      }
  ]

  options = prompt(formats, style=custom_style_3)
  message = input_path_message(options)

  questions = [
      {
          'type': 'input',
          'name': 'input_path',
          'message': message,
          'filter': os.path.expanduser,
          'validate': lambda value: validate_input_path(
              value, options['input_format'])
      },
      {
          'type': 'list',
          'name': 'saved_model_tags',
          'choices': available_tags,
          'message': 'What are the tags for the saved model?',
          'when': is_saved_model
      },
      {
          'type': 'list',
          'name': 'signature_name',
          'message': 'What is the signature name of the model?',
          'choices': available_signature_names,
          'when': is_saved_model
      },
      {
          'type': 'list',
          'name': 'quantization_bytes',
          'message': 'Do you want to compress the model? '
                     '(this will decrease the model precision.)',
          'choices': ['No compression',
                      'compress weights to 1/2 the size',
                      'compress weights to 1/4 the size'],
          'filter': quantization_type
      },
      {
          'type': 'input',
          'name': 'weight_shard_size_byte',
          'message': 'Please enter the shard size (in bytes) of the weight files.',
          'default': str(4 * 1024 * 1024),
          'when': lambda answers: of_values(answers, 'output_format',
                                            ['tfjs_layers_model'])
      },
      {
          'type': 'confirm',
          'name': 'split_weights_by_layer',
          'message': 'Do you want to split weights by layers?',
          'default': False,
          'when': lambda answers: of_values(answers, 'input_format',
                                            ['tfjs_layers_model'])
      },
      {
          'type': 'confirm',
          'name': 'skip_op_check',
          'message': 'Do you want to skip op validation?',
          'default': False,
          'when': lambda answers: of_values(answers, 'input_format',
                                            ['tf_saved_model', 'tf_hub'])
      },
      {
          'type': 'confirm',
          'name': 'strip_debug_ops',
          'message': 'Do you want to strip debug ops?',
          'default': True,
          'when': lambda answers: of_values(answers, 'input_format',
                                            ['tf_saved_model', 'tf_hub'])
      },
      {
          'type': 'input',
          'name': 'output_path',
          'message': 'Which directory do you want to save the converted model in?',
          'filter': os.path.expanduser,
          'validate': validate_output_path
      }
  ]
  params = prompt(questions, options, style=custom_style_3)

  command = generate_command(params)
  print(command)
  os.system(command)


if __name__ == '__main__':
  main()
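
Usage sketch (not part of the diff, assuming the package is installed so the
new console_scripts entry from setup.py is on the PATH): the interactive flow
starts from the shell as `tensorflowjs_cli`, or the module can be driven
directly:

    # Equivalent to running the generated `tensorflowjs_cli` executable.
    from tensorflowjs import cli
    cli.main()  # prompts for formats and paths, then runs tensorflowjs_converter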