Skip to content

Commit ff4ddd4

Browse files
lintian06copybara-github
authored andcommitted
Add recommendation model spec to model maker.
PiperOrigin-RevId: 347345279
1 parent 7b2850c commit ff4ddd4

File tree

5 files changed

+192
-28
lines changed

5 files changed

+192
-28
lines changed

lite/examples/recommendation/ml/model/recommendation_model_launcher_keras.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,37 +30,40 @@
3030

3131
FLAGS = flags.FLAGS
3232

33-
flags.DEFINE_string('training_data_filepattern', None,
34-
'File pattern of the training data.')
35-
flags.DEFINE_string('testing_data_filepattern', None,
36-
'File pattern of the training data.')
37-
flags.DEFINE_string('model_dir', None, 'Directory to store checkpoints.')
38-
flags.DEFINE_string(
39-
'params_path', None,
40-
'Path to the json file containing params needed to run '
41-
'p13n recommendation model.')
42-
flags.DEFINE_integer('batch_size', 1, 'Training batch size.')
43-
flags.DEFINE_float('learning_rate', 0.1, 'Learning rate.')
44-
flags.DEFINE_integer('steps_per_epoch', 10,
45-
'Number of steps to run in each epoch.')
46-
flags.DEFINE_integer('num_epochs', 10000, 'Number of training epochs.')
47-
flags.DEFINE_integer('num_eval_steps', 1000, 'Number of eval steps.')
48-
flags.DEFINE_enum('run_mode', 'train_and_eval', ['train_and_eval', 'export'],
49-
'Mode of the launcher, default value is: train_and_eval')
50-
flags.DEFINE_float('gradient_clip_norm', 1.0,
51-
'gradient_clip_norm <= 0 meaning no clip.')
52-
flags.DEFINE_integer('max_history_length', 10, 'Max length of user history.')
53-
flags.DEFINE_integer('num_predictions', 100,
54-
'Num of top predictions to output.')
55-
flags.DEFINE_string(
56-
'encoder_type', 'bow', 'Type of the encoder for context'
57-
'encoding, the value could be ["bow", "rnn", "cnn"].')
58-
flags.DEFINE_string('checkpoint_path', '', 'Path to the checkpoint.')
59-
6033
CONTEXT = 'context'
6134
LABEL = 'label'
6235

6336

37+
def define_flags():
38+
"""Define flags."""
39+
flags.DEFINE_string('training_data_filepattern', None,
40+
'File pattern of the training data.')
41+
flags.DEFINE_string('testing_data_filepattern', None,
42+
'File pattern of the training data.')
43+
flags.DEFINE_string('model_dir', None, 'Directory to store checkpoints.')
44+
flags.DEFINE_string(
45+
'params_path', None,
46+
'Path to the json file containing params needed to run '
47+
'p13n recommendation model.')
48+
flags.DEFINE_integer('batch_size', 1, 'Training batch size.')
49+
flags.DEFINE_float('learning_rate', 0.1, 'Learning rate.')
50+
flags.DEFINE_integer('steps_per_epoch', 10,
51+
'Number of steps to run in each epoch.')
52+
flags.DEFINE_integer('num_epochs', 10000, 'Number of training epochs.')
53+
flags.DEFINE_integer('num_eval_steps', 1000, 'Number of eval steps.')
54+
flags.DEFINE_enum('run_mode', 'train_and_eval', ['train_and_eval', 'export'],
55+
'Mode of the launcher, default value is: train_and_eval')
56+
flags.DEFINE_float('gradient_clip_norm', 1.0,
57+
'gradient_clip_norm <= 0 meaning no clip.')
58+
flags.DEFINE_integer('max_history_length', 10, 'Max length of user history.')
59+
flags.DEFINE_integer('num_predictions', 100,
60+
'Num of top predictions to output.')
61+
flags.DEFINE_string(
62+
'encoder_type', 'bow', 'Type of the encoder for context'
63+
'encoding, the value could be ["bow", "rnn", "cnn"].')
64+
flags.DEFINE_string('checkpoint_path', '', 'Path to the checkpoint.')
65+
66+
6467
class SimpleCheckpoint(tf.keras.callbacks.Callback):
6568
"""Keras callback to save tf.train.Checkpoints."""
6669

@@ -250,4 +253,5 @@ def main(_):
250253

251254

252255
if __name__ == '__main__':
256+
define_flags()
253257
app.run(main)

tensorflow_examples/lite/model_maker/core/task/model_spec/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import inspect
1717

18+
1819
from tensorflow_examples.lite.model_maker.core.task.model_spec import audio_spec
20+
from tensorflow_examples.lite.model_maker.core.task.model_spec import recommendation_spec
1921
from tensorflow_examples.lite.model_maker.core.task.model_spec.image_spec import efficientnet_lite0_spec
2022
from tensorflow_examples.lite.model_maker.core.task.model_spec.image_spec import efficientnet_lite1_spec
2123
from tensorflow_examples.lite.model_maker.core.task.model_spec.image_spec import efficientnet_lite2_spec
@@ -61,6 +63,11 @@
6163

6264
# Audio classification
6365
'audio_browser_fft': audio_spec.BrowserFFTSpec,
66+
67+
# Recommendation
68+
'recommendation_bow': recommendation_spec.recommendation_bow_spec,
69+
'recommendation_cnn': recommendation_spec.recommendation_cnn_spec,
70+
'recommendation_rnn': recommendation_spec.recommendation_rnn_spec,
6471
}
6572

6673
# List constants for supported models.
@@ -73,6 +80,11 @@
7380
]
7481
QUESTION_ANSWERING_MODELS = ['bert_qa', 'mobilebert_qa', 'mobilebert_qa_squad']
7582
AUDIO_CLASSIFICATION_MODELS = ['audio_browser_fft']
83+
RECOMMENDATION_MODELS = [
84+
'recommendation_bow',
85+
'recommendation_rnn',
86+
'recommendation_cnn',
87+
]
7688

7789

7890
def get(spec_or_str):

tensorflow_examples/lite/model_maker/core/task/model_spec/model_spec_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@
1818

1919
import os
2020

21+
from absl.testing import parameterized
2122
import tensorflow.compat.v2 as tf
2223

2324
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
2425

26+
MODELS = (
27+
ms.IMAGE_CLASSIFICATION_MODELS + ms.TEXT_CLASSIFICATION_MODELS +
28+
ms.QUESTION_ANSWERING_MODELS + ms.AUDIO_CLASSIFICATION_MODELS +
29+
ms.RECOMMENDATION_MODELS)
2530

26-
class ModelSpecTest(tf.test.TestCase):
31+
32+
class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
2733

2834
def test_get(self):
2935
spec = ms.get('mobilenet_v2')
@@ -35,6 +41,12 @@ def test_get(self):
3541
spec = ms.get(ms.mobilenet_v2_spec)
3642
self.assertIsInstance(spec, ms.ImageModelSpec)
3743

44+
@parameterized.parameters(MODELS)
45+
def test_get_not_none(self, model):
46+
spec = ms.get(model)
47+
self.assertIsNotNone(spec)
48+
49+
def test_get_raises(self):
3850
with self.assertRaises(KeyError):
3951
ms.get('not_exist_model_spec')
4052

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Recommendation model specification."""
15+
16+
import functools
17+
18+
import tensorflow as tf # pylint: disable=unused-import
19+
20+
HAS_RECOMMENDATION = True
21+
try:
22+
from tensorflow_examples.lite.examples.recommendation.ml.model import recommendation_model as rm # pylint: disable=g-import-not-at-top
23+
except ImportError:
24+
HAS_RECOMMENDATION = False
25+
26+
27+
class RecommendationSpec(object):
28+
"""Recommendation model spec."""
29+
30+
def __init__(self,
31+
encoder_type='bow',
32+
context_embedding_dim=128,
33+
label_embedding_dim=32,
34+
item_vocab_size=16,
35+
num_predictions=10,
36+
hidden_layer_dim_ratios=None,
37+
conv_num_filter_ratios=None,
38+
conv_kernel_size=None,
39+
lstm_num_units=None):
40+
"""Initialize spec.
41+
42+
Args:
43+
encoder_type: str, encoder type. One of ('bow', 'cnn', 'rnn').
44+
context_embedding_dim: int, dimension of context embedding layer.
45+
label_embedding_dim: int, dimension of label embedding layer.
46+
item_vocab_size: int, the size of items to be predict.
47+
num_predictions: int, the number of top-K predictions in the output.
48+
hidden_layer_dim_ratios: list of float, number of units in hidden layers
49+
specified by ratios. default: [1.0, 0.5, 0.25].
50+
conv_num_filter_ratios: list of int, for 'cnn', Conv1D layers' filter
51+
ratios based on context_embedding_dim.
52+
conv_kernel_size: int, for 'rnn', Conv1D layers' kernel size.
53+
lstm_num_units: int, for 'rnn', LSTM layer's unit number.
54+
"""
55+
hidden_layer_dim_ratios = hidden_layer_dim_ratios or [1.0, 0.5, 0.25]
56+
57+
if encoder_type == 'cnn':
58+
conv_num_filter_ratios = conv_num_filter_ratios or [2, 4]
59+
conv_kernel_size = conv_kernel_size or 4
60+
elif encoder_type == 'rnn':
61+
lstm_num_units = lstm_num_units or 16
62+
63+
self.encoder_type = encoder_type
64+
self.context_embedding_dim = context_embedding_dim
65+
self.label_embedding_dim = label_embedding_dim
66+
self.hidden_layer_dim_ratios = hidden_layer_dim_ratios
67+
self.item_vocab_size = item_vocab_size
68+
self.num_predictions = num_predictions
69+
self.conv_num_filter_ratios = conv_num_filter_ratios
70+
self.conv_kernel_size = conv_kernel_size
71+
self.lstm_num_units = lstm_num_units
72+
73+
self._params = {
74+
'encoder_type': encoder_type,
75+
'context_embedding_dim': context_embedding_dim,
76+
'label_embedding_dim': label_embedding_dim,
77+
'hidden_layer_dim_ratios': hidden_layer_dim_ratios,
78+
'item_vocab_size': item_vocab_size,
79+
'num_predictions': num_predictions,
80+
'conv_num_filter_ratios': conv_num_filter_ratios,
81+
'conv_kernel_size': conv_kernel_size,
82+
'lstm_num_units': lstm_num_units,
83+
}
84+
85+
def create_model(self):
86+
"""Creates recommendation model based on params."""
87+
if not HAS_RECOMMENDATION:
88+
return None
89+
return rm.RecommendationModel(self._params)
90+
91+
92+
recommendation_bow_spec = functools.partial(
93+
RecommendationSpec, encoder_type='bow')
94+
recommendation_cnn_spec = functools.partial(
95+
RecommendationSpec, encoder_type='cnn')
96+
recommendation_rnn_spec = functools.partial(
97+
RecommendationSpec, encoder_type='rnn')
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for recommendation spec."""
15+
16+
from absl.testing import parameterized
17+
import tensorflow.compat.v2 as tf
18+
19+
from tensorflow_examples.lite.model_maker.core.task.model_spec import recommendation_spec
20+
21+
22+
class RecommendationSpecTest(tf.test.TestCase, parameterized.TestCase):
23+
24+
@parameterized.parameters(
25+
('bow'),
26+
('cnn'),
27+
('rnn'),
28+
)
29+
def test_create_recommendation_model(self, encoder_type):
30+
spec = recommendation_spec.RecommendationSpec(encoder_type)
31+
model = spec.create_model()
32+
if recommendation_spec.HAS_RECOMMENDATION:
33+
self.assertIsInstance(model, recommendation_spec.rm.RecommendationModel)
34+
else:
35+
self.assertIsNone(model)
36+
37+
38+
if __name__ == '__main__':
39+
tf.test.main()

0 commit comments

Comments
 (0)