Skip to content

Commit 6922595

Browse files
Lifannrhdong
authored andcommitted
Update demo for new eager mode API
1 parent 5f45184 commit 6922595

File tree

5 files changed

+139
-197
lines changed

5 files changed

+139
-197
lines changed

demo/dynamic_embedding/amazon-us-reviews-digital-video-games/video_game_model.py

Lines changed: 0 additions & 150 deletions
This file was deleted.

demo/dynamic_embedding/amazon-us-reviews-digital-video-games/README.md renamed to demo/dynamic_embedding/amazon-video-games-keras-eager/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ It will produce a model to `export_dir`.
1313

1414
## Inference:
1515
```bash
16-
python main.py --mode=test --export_dir="export" --batch_size=10
16+
python main.py --mode=test --export_dir="export" --batch_size=64
1717
```
18-
It will print accuracy to the prediction on verified purchase of the digital video games.
18+
It will print accuracy to the prediction on verified purchase of the digital video games.

demo/dynamic_embedding/amazon-us-reviews-digital-video-games/feature.py renamed to demo/dynamic_embedding/amazon-video-games-keras-eager/feature.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import tensorflow as tf
22
import tensorflow_datasets as tfds
3-
import sys
43

5-
ENCODDING_SEGMENT_LENGTH = 1000000
4+
ENCODING_SEGMENT_LENGTH = 1000000
65
NON_LETTER_OR_NUMBER_PATTERN = r'[^a-zA-Z0-9]'
76

87
FAETURES = [
@@ -12,6 +11,8 @@
1211
]
1312
LABEL = 'verified_purchase'
1413

14+
NUM_FEATURE_SLOTS = 0
15+
1516

1617
class _RawFeature(object):
1718
"""
@@ -22,13 +23,15 @@ def __init__(self, dtype, category):
2223
if not isinstance(category, int):
2324
raise TypeError('category must be an integer.')
2425
self.category = category
26+
global NUM_FEATURE_SLOTS
27+
NUM_FEATURE_SLOTS = max(NUM_FEATURE_SLOTS, self.category)
2528

2629
def encode(self, tensor):
2730
raise NotImplementedError
2831

2932
def match_category(self, tensor):
30-
min_code = self.category * ENCODDING_SEGMENT_LENGTH
31-
max_code = (self.category + 1) * ENCODDING_SEGMENT_LENGTH
33+
min_code = self.category * ENCODING_SEGMENT_LENGTH
34+
max_code = (self.category + 1) * ENCODING_SEGMENT_LENGTH
3235
mask = tf.math.logical_and(tf.greater_equal(tensor, min_code),
3336
tf.less(tensor, max_code))
3437
return mask
@@ -40,8 +43,8 @@ def __init__(self, dtype, category):
4043
super(_StringFeature, self).__init__(dtype, category)
4144

4245
def encode(self, tensor):
43-
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODDING_SEGMENT_LENGTH)
44-
tensor += ENCODDING_SEGMENT_LENGTH * self.category
46+
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
47+
tensor += ENCODING_SEGMENT_LENGTH * self.category
4548
return tensor
4649

4750

@@ -53,8 +56,8 @@ def __init__(self, dtype, category):
5356
def encode(self, tensor):
5457
tensor = tf.strings.regex_replace(tensor, NON_LETTER_OR_NUMBER_PATTERN, ' ')
5558
tensor = tf.strings.split(tensor, sep=' ').to_tensor('')
56-
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODDING_SEGMENT_LENGTH)
57-
tensor += ENCODDING_SEGMENT_LENGTH * self.category
59+
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
60+
tensor += ENCODING_SEGMENT_LENGTH * self.category
5861
return tensor
5962

6063

@@ -65,23 +68,23 @@ def __init__(self, dtype, category):
6568

6669
def encode(self, tensor):
6770
tensor = tf.as_string(tensor)
68-
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODDING_SEGMENT_LENGTH)
69-
tensor += ENCODDING_SEGMENT_LENGTH * self.category
71+
tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
72+
tensor += ENCODING_SEGMENT_LENGTH * self.category
7073
return tensor
7174

7275

7376
FEATURE_AND_ENCODER = {
74-
'customer_id': _StringFeature(tf.string, 1),
75-
'helpful_votes': _IntegerFeature(tf.int32, 2),
76-
'product_category': _StringFeature(tf.string, 3),
77-
'product_id': _StringFeature(tf.string, 4),
78-
'product_parent': _StringFeature(tf.string, 5),
79-
'product_title': _TextFeature(tf.string, 6),
80-
#'review_body': _TextFeature(tf.string, 7), # bad feature
81-
'review_headline': _TextFeature(tf.string, 8),
82-
'review_id': _StringFeature(tf.string, 9),
83-
'star_rating': _IntegerFeature(tf.int32, 10),
84-
'total_votes': _IntegerFeature(tf.int32, 11),
77+
'customer_id': _StringFeature(tf.string, 0),
78+
'helpful_votes': _IntegerFeature(tf.int32, 1),
79+
'product_category': _StringFeature(tf.string, 2),
80+
'product_id': _StringFeature(tf.string, 3),
81+
'product_parent': _StringFeature(tf.string, 4),
82+
'product_title': _TextFeature(tf.string, 5),
83+
'review_headline': _TextFeature(tf.string, 6),
84+
'review_id': _StringFeature(tf.string, 7),
85+
'star_rating': _IntegerFeature(tf.int32, 8),
86+
'total_votes': _IntegerFeature(tf.int32, 9),
87+
#'review_body': _TextFeature(tf.string, 10), # bad feature
8588
}
8689

8790

@@ -99,6 +102,12 @@ def encode_feature(data):
99102
return collected_features
100103

101104

105+
@tf.function
106+
def get_category(tensor):
107+
x = tf.math.floordiv(tensor, ENCODING_SEGMENT_LENGTH)
108+
return x
109+
110+
102111
def get_labels(data):
103112
return data['verified_purchase']
104113

demo/dynamic_embedding/amazon-us-reviews-digital-video-games/main.py renamed to demo/dynamic_embedding/amazon-video-games-keras-eager/main.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import feature
22
import video_game_model
33
import tensorflow as tf
4+
45
from tensorflow_recommenders_addons import dynamic_embedding as de
56

67
from absl import flags
@@ -11,10 +12,11 @@
1112
flags.DEFINE_integer('embedding_size', 4, 'Embedding size.')
1213
flags.DEFINE_integer('shuffle_size', 3000,
1314
'Shuffle pool size for input examples.')
14-
flags.DEFINE_integer('reserved_features', 30000,
15+
flags.DEFINE_integer('max_size', 100000,
1516
'Number of reserved features in embedding.')
1617
flags.DEFINE_string('export_dir', './export_dir', 'Directory to export model.')
1718
flags.DEFINE_string('mode', 'train', 'Select the running mode: train or test.')
19+
flags.DEFINE_string('save_format', 'keras', 'options: keras, tf')
1820

1921
FLAGS = flags.FLAGS
2022

@@ -27,6 +29,13 @@ def train(num_steps):
2729
# Create a model
2830
model = video_game_model.VideoGameDnn(batch_size=FLAGS.batch_size,
2931
embedding_size=FLAGS.embedding_size)
32+
optimizer = tf.keras.optimizers.Adam(1E-3, clipnorm=None)
33+
optimizer = de.DynamicEmbeddingOptimizer(optimizer)
34+
auc = tf.keras.metrics.AUC(num_thresholds=1000)
35+
accuracy = tf.keras.metrics.BinaryAccuracy(dtype=tf.float32)
36+
model.compile(optimizer=optimizer,
37+
loss='binary_crossentropy',
38+
metrics=[accuracy, auc])
3039

3140
# Get data iterator
3241
iterator = feature.initialize_dataset(batch_size=FLAGS.batch_size,
@@ -39,29 +48,31 @@ def train(num_steps):
3948
try:
4049
for step in range(num_steps):
4150
features, labels = feature.input_fn(iterator)
42-
loss, auc = model.train(features, labels)
43-
44-
# To avoid too many features burst the memory, we restrict
45-
# the model embedding layer to `reserved_features` features.
46-
# And the restriction behavior will be triggered when it gets
47-
# over `reserved_features * 1.2`.
48-
model.embedding_store.restrict(FLAGS.reserved_features,
49-
trigger=int(FLAGS.reserved_features * 1.2))
5051

5152
if step % 10 == 0:
52-
print('step: {}, loss: {}, var_size: {}, auc: {}'.format(
53-
step, loss, model.embedding_store.size(), auc))
53+
verbose = 1
54+
else:
55+
verbose = 0
56+
57+
model.fit(features, labels, steps_per_epoch=1, epochs=1, verbose=verbose)
58+
59+
if verbose > 0:
60+
print('step: {}, size of sparse domain: {}'.format(
61+
step, model.embedding_store.size()))
62+
model.embedding_store.restrict(int(FLAGS.max_size * 0.8),
63+
trigger=FLAGS.max_size)
5464

5565
except tf.errors.OutOfRangeError:
5666
print('Run out the training data.')
5767

58-
# Set TFRA ops become legit.
59-
options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
60-
6168
# Save the model for inference.
62-
inference_model = video_game_model.VideoGameDnnInference(model)
63-
inference_model(feature.input_fn(iterator)[0])
64-
inference_model.save('export', signatures=None, options=options)
69+
options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
70+
if FLAGS.save_format == 'tf':
71+
model.save(FLAGS.export_dir, options=options)
72+
elif FLAGS.save_format == 'keras':
73+
tf.keras.models.save_model(model, FLAGS.export_dir, options=options)
74+
else:
75+
raise NotImplemented
6576

6677

6778
def test(num_steps):
@@ -70,25 +81,36 @@ def test(num_steps):
7081
"""
7182

7283
# Load model.
73-
options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
74-
model = tf.saved_model.load('export', tags='serve', options=options)
75-
sig = model.signatures['serving_default']
84+
options = tf.saved_model.LoadOptions()
85+
if FLAGS.save_format == 'tf':
86+
model = tf.saved_model.load(FLAGS.export_dir, tags='serve')
87+
88+
def model_fn(x):
89+
return model.signatures['serving_default'](x)['output_1']
90+
91+
elif FLAGS.save_format == 'keras':
92+
model = tf.keras.models.load_model(FLAGS.export_dir)
93+
model_fn = model.__call__
94+
95+
else:
96+
raise NotImplemented
7697

7798
# Get data iterator
7899
iterator = feature.initialize_dataset(batch_size=FLAGS.batch_size,
79100
split='train',
80101
shuffle_size=0,
81102
skips=100000)
82103

83-
# Do tests.
104+
# Test click-ratio
105+
ctr = tf.metrics.Accuracy()
84106
for step in range(num_steps):
85107
features, labels = feature.input_fn(iterator)
86-
probabilities = sig(features)['output_1']
108+
probabilities = model_fn(features)
87109
probabilities = tf.reshape(probabilities, (-1))
88110
preds = tf.cast(tf.round(probabilities), dtype=tf.int32)
89111
labels = tf.cast(labels, dtype=tf.int32)
90-
ctr = tf.metrics.Accuracy()(labels, preds)
91-
print("step: {}, ctr: {}".format(step, ctr))
112+
ctr.update_state(labels, preds)
113+
print("step: {}, ctr: {}".format(step, ctr.result()))
92114

93115

94116
def main(argv):

0 commit comments

Comments
 (0)