Skip to content

Commit 22f86fd

Browse files
lintian06copybara-github
authored andcommitted
Refactor recommendation ml code.
PiperOrigin-RevId: 348577962
1 parent 117b0bb commit 22f86fd

File tree

3 files changed

+66
-37
lines changed

3 files changed

+66
-37
lines changed

lite/examples/recommendation/ml/data/example_generation_movielens.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,10 @@
3535
import tensorflow as tf
3636

3737
FLAGS = flags.FLAGS
38-
flags.DEFINE_string("data_dir", "/tmp",
39-
"Path to download and store movielens data.")
40-
flags.DEFINE_string("output_dir", None,
41-
"Path to the directory of output files.")
42-
flags.DEFINE_bool("build_movie_vocab", True,
43-
"If yes, generate sorted movie vocab.")
44-
flags.DEFINE_integer("min_timeline_length", 3,
45-
"The minimum timeline length to construct examples.")
46-
flags.DEFINE_integer("max_context_length", 10,
47-
"The maximun length of user context history.")
48-
4938
# Permalinks to download movielens data.
5039
MOVIELENS_1M_URL = "http://files.grouplens.org/datasets/movielens/ml-1m.zip"
5140
MOVIELENS_ZIP_FILENAME = "ml-1m.zip"
41+
MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"
5242
MOVIELENS_EXTRACTED_DIR = "ml-1m"
5343
RATINGS_FILE_NAME = "ratings.dat"
5444
MOVIES_FILE_NAME = "movies.dat"
@@ -60,6 +50,19 @@
6050
OOV_MOVIE_ID = 0
6151

6252

53+
def define_flags():
54+
flags.DEFINE_string("data_dir", "/tmp",
55+
"Path to download and store movielens data.")
56+
flags.DEFINE_string("output_dir", None,
57+
"Path to the directory of output files.")
58+
flags.DEFINE_bool("build_movie_vocab", True,
59+
"If yes, generate sorted movie vocab.")
60+
flags.DEFINE_integer("min_timeline_length", 3,
61+
"The minimum timeline length to construct examples.")
62+
flags.DEFINE_integer("max_context_length", 10,
63+
"The maximun length of user context history.")
64+
65+
6366
def download_and_extract_data(data_directory, url=MOVIELENS_1M_URL):
6467
"""Download and extract zip containing MovieLens data to a given directory.
6568
@@ -74,6 +77,8 @@ def download_and_extract_data(data_directory, url=MOVIELENS_1M_URL):
7477
path_to_zip = tf.keras.utils.get_file(
7578
fname=MOVIELENS_ZIP_FILENAME,
7679
origin=url,
80+
file_hash=MOVIELENS_ZIP_HASH,
81+
hash_algorithm="sha256",
7782
extract=True,
7883
cache_dir=data_directory)
7984
extracted_file_dir = os.path.join(
@@ -154,10 +159,13 @@ def generate_examples_from_timelines(timelines,
154159

155160

156161
def write_tfrecords(tf_examples, filename):
157-
"""Write tf examples to tfrecord file."""
162+
"""Writes tf examples to tfrecord file, and returns the count."""
158163
with tf.io.TFRecordWriter(filename) as file_writer:
164+
i = 0
159165
for example in tf_examples:
160166
file_writer.write(example)
167+
i += 1
168+
return i
161169

162170

163171
def generate_sorted_movie_vocab(movies_df, movie_counts):
@@ -176,8 +184,9 @@ def write_vocab_json(vocab_movies, filename):
176184
json.dump(vocab_movies, jsonfile, indent=2)
177185

178186

179-
def main(_):
180-
data_dir = FLAGS.data_dir
187+
def generate_datasets(data_dir, output_dir, min_timeline_length,
188+
max_context_length, build_movie_vocab):
189+
"""Generates train and test datasets as TFRecord, and returns stats."""
181190
if not tf.io.gfile.exists(data_dir):
182191
tf.io.gfile.makedirs(data_dir)
183192

@@ -186,24 +195,37 @@ def main(_):
186195
timelines, movie_counts = convert_to_timelines(ratings_df)
187196
train_examples, test_examples = generate_examples_from_timelines(
188197
timelines=timelines,
189-
min_timeline_len=FLAGS.min_timeline_length,
190-
max_context_len=FLAGS.max_context_length)
191-
192-
if not tf.io.gfile.exists(FLAGS.output_dir):
193-
tf.io.gfile.makedirs(FLAGS.output_dir)
194-
write_tfrecords(
195-
tf_examples=train_examples,
196-
filename=os.path.join(FLAGS.output_dir, OUTPUT_TRAINING_DATA_FILENAME))
197-
write_tfrecords(
198-
tf_examples=test_examples,
199-
filename=os.path.join(FLAGS.output_dir, OUTPUT_TESTING_DATA_FILENAME))
200-
if FLAGS.build_movie_vocab:
198+
min_timeline_len=min_timeline_length,
199+
max_context_len=max_context_length)
200+
201+
if not tf.io.gfile.exists(output_dir):
202+
tf.io.gfile.makedirs(output_dir)
203+
train_file = os.path.join(output_dir, OUTPUT_TRAINING_DATA_FILENAME)
204+
train_size = write_tfrecords(tf_examples=train_examples, filename=train_file)
205+
test_file = os.path.join(output_dir, OUTPUT_TESTING_DATA_FILENAME)
206+
test_size = write_tfrecords(tf_examples=test_examples, filename=test_file)
207+
stats = {
208+
"train_size": train_size,
209+
"test_size": test_size,
210+
"train_file": train_file,
211+
"test_file": test_file,
212+
}
213+
if build_movie_vocab:
201214
vocab_movies = generate_sorted_movie_vocab(
202215
movies_df=movies_df, movie_counts=movie_counts)
203-
write_vocab_json(
204-
vocab_movies=vocab_movies,
205-
filename=os.path.join(FLAGS.output_dir, OUTPUT_MOVIE_VOCAB_FILENAME))
216+
vocab_file = os.path.join(output_dir, OUTPUT_MOVIE_VOCAB_FILENAME)
217+
write_vocab_json(vocab_movies=vocab_movies, filename=vocab_file)
218+
stats.update(vocab_size=len(vocab_movies), vocab_file=vocab_file)
219+
return stats
220+
221+
222+
def main(_):
223+
stats = generate_datasets(FLAGS.data_dir, FLAGS.output_dir,
224+
FLAGS.min_timeline_length, FLAGS.max_context_length,
225+
FLAGS.build_movie_vocab)
226+
tf.compat.v1.logging.info("Generated dataset: %s", stats)
206227

207228

208229
if __name__ == "__main__":
230+
define_flags()
209231
app.run(main)

lite/examples/recommendation/ml/model/recommendation_model_launcher_keras.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def on_epoch_end(self, epoch, logs=None):
7676
self.checkpoint_manager.save(checkpoint_number=step_counter)
7777

7878

79-
def get_input_fn(data_filepattern):
79+
def get_input_fn(data_filepattern, batch_size):
8080
"""Get input_fn for recommendation model estimator."""
8181

8282
def decode_example(serialized_proto):
@@ -112,7 +112,7 @@ def input_fn():
112112
d = d.repeat()
113113
d = d.shuffle(buffer_size=100)
114114
d = d.map(decode_example)
115-
d = d.batch(FLAGS.batch_size, drop_remainder=True)
115+
d = d.batch(batch_size, drop_remainder=True)
116116
d = d.prefetch(1)
117117
return d
118118

@@ -226,8 +226,9 @@ def main(_):
226226
params['num_predictions'] = FLAGS.num_predictions
227227

228228
logger.info('Setting up train and eval input_fns.')
229-
train_input_fn = get_input_fn(FLAGS.training_data_filepattern)
230-
eval_input_fn = get_input_fn(FLAGS.testing_data_filepattern)
229+
train_input_fn = get_input_fn(FLAGS.training_data_filepattern,
230+
FLAGS.batch_size)
231+
eval_input_fn = get_input_fn(FLAGS.testing_data_filepattern, FLAGS.batch_size)
231232

232233
logger.info('Build keras model for mode: {}.'.format(FLAGS.run_mode))
233234
model = build_keras_model(params=params)

lite/examples/recommendation/ml/model/recommendation_model_launcher_keras_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,15 @@ def setUp(self):
7373
FLAGS.encoder_type = 'cnn'
7474
FLAGS.num_predictions = 10
7575
FLAGS.max_history_length = 10
76+
FLAGS.batch_size = 1
7677

7778
def testModelFnTrainModeExecute(self):
7879
"""Verifies that 'model_fn' can be executed in train and eval mode."""
7980
self.params['encoder_type'] = FLAGS.encoder_type
80-
train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern)
81-
eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern)
81+
train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern,
82+
FLAGS.batch_size)
83+
eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern,
84+
FLAGS.batch_size)
8285
model = launcher.build_keras_model(params=self.params)
8386
launcher.train_and_eval(
8487
model=model,
@@ -96,8 +99,10 @@ def testModelFnExportModeExecute(self):
9699
"""Verifies model can be exported to savedmodel and tflite model."""
97100
self.params['encoder_type'] = FLAGS.encoder_type
98101
self.params['num_predictions'] = FLAGS.num_predictions
99-
train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern)
100-
eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern)
102+
train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern,
103+
FLAGS.batch_size)
104+
eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern,
105+
FLAGS.batch_size)
101106
model = launcher.build_keras_model(params=self.params)
102107
launcher.train_and_eval(
103108
model=model,
@@ -142,4 +147,5 @@ def testModelFnExportModeExecute(self):
142147

143148

144149
if __name__ == '__main__':
150+
launcher.define_flags()
145151
tf.test.main()

0 commit comments

Comments
 (0)