
Commit 023fc2b

Add unit test, official flags, and benchmark logs for recommendation model (#4343)
* Add unit test, official flags, and benchmark logs
* Fix checking errors
* Reorder imports to fix lints
* Address comments and correct model layers
* Add dataset checking
1 parent 3a624c2 commit 023fc2b

File tree

9 files changed: +524, -183 lines


official/recommendation/constants.py

Lines changed: 0 additions & 3 deletions
@@ -18,9 +18,6 @@
 TEST_RATINGS_FILENAME = 'test-ratings.csv'
 TEST_NEG_FILENAME = 'test-negative.csv'

-TRAIN_DATA = 'train_data.csv'
-TEST_DATA = 'test_data.csv'
-
 USER = "user_id"
 ITEM = "item_id"
 RATING = "rating"

official/recommendation/data_download.py

Lines changed: 29 additions & 12 deletions
@@ -21,19 +21,23 @@
 from __future__ import division
 from __future__ import print_function

-import argparse
 import collections
 import os
 import sys
 import time
 import zipfile

+# pylint: disable=g-bad-import-order
 import numpy as np
 import pandas as pd
 from six.moves import urllib  # pylint: disable=redefined-builtin
+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf
+# pylint: enable=g-bad-import-order

-from official.recommendation import constants  # pylint: disable=g-bad-import-order
+from official.recommendation import constants
+from official.utils.flags import core as flags_core

 # URL to download dataset
 _DATA_URL = "http://files.grouplens.org/datasets/movielens/"

@@ -306,6 +310,10 @@ def main(_):

   make_dir(FLAGS.data_dir)

+  assert FLAGS.dataset, (
+      "Please specify which dataset to download. "
+      "Two datasets are available: ml-1m and ml-20m.")
+
   # Download the zip dataset
   dataset_zip = FLAGS.dataset + ".zip"
   file_path = os.path.join(FLAGS.data_dir, dataset_zip)

@@ -335,14 +343,23 @@ def _progress(count, block_size, total_size):
   parse_file_to_csv(FLAGS.data_dir, FLAGS.dataset)


+def define_data_download_flags():
+  """Add flags specifying data download arguments."""
+  flags.DEFINE_string(
+      name="data_dir", default="/tmp/movielens-data/",
+      help=flags_core.help_wrap(
+          "Directory to download and extract data."))
+
+  flags.DEFINE_enum(
+      name="dataset", default=None,
+      enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
+      help=flags_core.help_wrap(
+          "Dataset to be trained and evaluated. Two datasets are available "
+          ": ml-1m and ml-20m."))
+
+
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--data_dir", type=str, default="/tmp/movielens-data/",
-      help="Directory to download data and extract the zip.")
-  parser.add_argument(
-      "--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"],
-      help="Dataset to be trained and evaluated.")
-
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(argv=[sys.argv[0]] + unparsed)
+  tf.logging.set_verbosity(tf.logging.INFO)
+  define_data_download_flags()
+  FLAGS = flags.FLAGS
+  absl_app.run(main)
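
The argparse-to-absl migration above follows the flag-definition pattern shared by the official models. Below is a minimal, self-contained sketch of that pattern; the file name, printout, and plain help strings are illustrative assumptions, not part of this commit (the committed code wraps its help text with flags_core.help_wrap).

# absl_flags_sketch.py -- hypothetical stand-alone illustration of the
# flag pattern adopted in data_download.py; not part of this commit.
from absl import app
from absl import flags

flags.DEFINE_string(
    name="data_dir", default="/tmp/movielens-data/",
    help="Directory to download and extract data.")
flags.DEFINE_enum(
    name="dataset", default=None, enum_values=["ml-1m", "ml-20m"],
    case_sensitive=False, help="Dataset to be trained and evaluated.")

FLAGS = flags.FLAGS


def main(_):
  # app.run() parses the command line before invoking main.
  print("data_dir=%s dataset=%s" % (FLAGS.data_dir, FLAGS.dataset))


if __name__ == "__main__":
  app.run(main)

With flags defined this way the script is driven as, for example, python data_download.py --dataset ml-1m --data_dir /tmp/movielens-data, and the new assert in main() fails fast when --dataset is omitted instead of silently defaulting to ml-1m as the old argparse version did.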

official/recommendation/dataset.py

Lines changed: 35 additions & 57 deletions
@@ -17,18 +17,12 @@
 Load the training dataset and evaluation dataset from csv file into memory.
 Prepare input for model training and evaluation.
 """
-import time
-
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf

 from official.recommendation import constants  # pylint: disable=g-bad-import-order

-# The column names and types of csv file
-_CSV_COLUMN_NAMES = [constants.USER, constants.ITEM, constants.RATING]
-_CSV_TYPES = [[0], [0], [0]]
-
 # The buffer size for shuffling train dataset.
 _SHUFFLE_BUFFER_SIZE = 1024

@@ -37,7 +31,7 @@ class NCFDataSet(object):
   """A class containing data information for model training and evaluation."""

   def __init__(self, train_data, num_users, num_items, num_negatives,
-               true_items, all_items):
+               true_items, all_items, all_eval_data):
     """Initialize NCFDataset class.

     Args:

@@ -50,17 +44,19 @@ def __init__(self, train_data, num_users, num_items, num_negatives,
         evaluation. Each entry is a latest positive instance for one user.
       all_items: A nested list, all items for evaluation, and each entry is the
         evaluation items for one user.
+      all_eval_data: A numpy array of eval/test dataset.
     """
     self.train_data = train_data
     self.num_users = num_users
     self.num_items = num_items
     self.num_negatives = num_negatives
     self.eval_true_items = true_items
     self.eval_all_items = all_items
+    self.all_eval_data = all_eval_data


 def load_data(file_name):
-  """Load data from a csv file which splits on \t."""
+  """Load data from a csv file which splits on tab key."""
   lines = tf.gfile.Open(file_name, "r").readlines()

   # Process the file line by line

@@ -122,13 +118,11 @@ def data_preprocessing(train_fname, test_fname, test_neg_fname, num_negatives):
     all_items.append(items)  # All items (including positive and negative items)
     all_test_data.extend(users_items)  # Generate test dataset

-  # Save test dataset into csv file
-  np.savetxt(constants.TEST_DATA, np.asarray(all_test_data).astype(int),
-             fmt="%i", delimiter=",")
-
   # Create NCFDataset object
   ncf_dataset = NCFDataSet(
-      train_data, num_users, num_items, num_negatives, true_items, all_items)
+      train_data, num_users, num_items, num_negatives, true_items, all_items,
+      np.asarray(all_test_data)
+  )

   return ncf_dataset

@@ -144,6 +138,9 @@ def generate_train_dataset(train_data, num_items, num_negatives):
     num_items: An integer, the number of items in positive training instances.
     num_negatives: An integer, the number of negative training instances
       following positive training instances. It is 4 by default.
+
+  Returns:
+    A numpy array of training dataset.
   """
   all_train_data = []
   # A set with user-item tuples

@@ -158,13 +155,10 @@
         j = np.random.randint(num_items)
       all_train_data.append([u, j, 0])

-  # Save the train dataset into a csv file
-  np.savetxt(constants.TRAIN_DATA, np.asarray(all_train_data).astype(int),
-             fmt="%i", delimiter=",")
+  return np.asarray(all_train_data)


-def input_fn(training, batch_size, repeat=1, ncf_dataset=None,
-             num_parallel_calls=1):
+def input_fn(training, batch_size, ncf_dataset, repeat=1):
   """Input function for model training and evaluation.

   The train input consists of 1 positive instance (user and item have

@@ -176,55 +170,39 @@
   Args:
     training: A boolean flag for training mode.
     batch_size: An integer, batch size for training and evaluation.
+    ncf_dataset: An NCFDataSet object, which contains the information about
+      training and test data.
     repeat: An integer, how many times to repeat the dataset.
-    ncf_dataset: An NCFDataSet object, which contains the information to
-      generate negative training instances.
-    num_parallel_calls: An integer, number of cpu cores for parallel input
-      processing.

   Returns:
     dataset: A tf.data.Dataset object containing examples loaded from the files.
   """
-  # Default test file name
-  file_name = constants.TEST_DATA
-
   # Generate random negative instances for training in each epoch
   if training:
-    t1 = time.time()
-    generate_train_dataset(
+    train_data = generate_train_dataset(
         ncf_dataset.train_data, ncf_dataset.num_items,
         ncf_dataset.num_negatives)
-    file_name = constants.TRAIN_DATA
-    tf.logging.info(
-        "Generating training instances: {:.1f}s".format(time.time() - t1))
-
-  # Create a dataset containing the text lines.
-  dataset = tf.data.TextLineDataset(file_name)
-
-  # Test dataset only has two fields while train dataset has three
-  num_cols = len(_CSV_COLUMN_NAMES) - 1
-  # Shuffle the dataset for training
-  if training:
+    # Get train features and labels
+    train_features = [
+        (constants.USER, np.expand_dims(train_data[:, 0], axis=1)),
+        (constants.ITEM, np.expand_dims(train_data[:, 1], axis=1))
+    ]
+    train_labels = [
+        (constants.RATING, np.expand_dims(train_data[:, 2], axis=1))]
+
+    dataset = tf.data.Dataset.from_tensor_slices(
+        (dict(train_features), dict(train_labels))
+    )
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER_SIZE)
-    num_cols += 1
-
-  def _parse_csv(line):
-    """Parse each line of the csv file."""
-    # Decode the line into its fields
-    fields = tf.decode_csv(line, record_defaults=_CSV_TYPES[0:num_cols])
-    fields = [tf.expand_dims(field, axis=0) for field in fields]
-
-    # Pack the result into a dictionary
-    features = dict(zip(_CSV_COLUMN_NAMES[0:num_cols], fields))
-    # Separate the labels from the features for training
-    if training:
-      labels = features.pop(constants.RATING)
-      return features, labels
-    # Return features only for test/prediction
-    return features
-
-  # Parse each line into a dictionary
-  dataset = dataset.map(_parse_csv, num_parallel_calls=num_parallel_calls)
+  else:
+    # Create eval/test dataset
+    test_user = ncf_dataset.all_eval_data[:, 0]
+    test_item = ncf_dataset.all_eval_data[:, 1]
+    test_features = [
+        (constants.USER, np.expand_dims(test_user, axis=1)),
+        (constants.ITEM, np.expand_dims(test_item, axis=1))]
+
+    dataset = tf.data.Dataset.from_tensor_slices(dict(test_features))

   # Repeat and batch the dataset
   dataset = dataset.repeat(repeat)
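
Because input_fn now reads training and eval data from the in-memory NCFDataSet object instead of intermediate CSV files, callers have to bind its arguments before handing it to an Estimator, which expects a zero-argument input function. A hedged sketch of that wiring follows; the file paths, batch size, and the commented-out estimator calls are assumptions for illustration, not code from this commit.

# Hypothetical wiring for the reworked input_fn; the paths and
# hyperparameter values below are placeholders.
import functools

from official.recommendation import dataset

# Build the in-memory dataset once; eval data now lives on the object
# (ncf_dataset.all_eval_data) rather than in a test_data.csv file.
ncf_dataset = dataset.data_preprocessing(
    "train-ratings.csv", "test-ratings.csv", "test-negative.csv", 4)

# functools.partial turns input_fn into the zero-argument callable that
# tf.estimator.Estimator train/predict expect.
train_input_fn = functools.partial(
    dataset.input_fn, training=True, batch_size=256,
    ncf_dataset=ncf_dataset, repeat=1)
eval_input_fn = functools.partial(
    dataset.input_fn, training=False, batch_size=256,
    ncf_dataset=ncf_dataset)

# estimator = tf.estimator.Estimator(model_fn=..., model_dir=...)
# estimator.train(input_fn=train_input_fn)
# predictions = estimator.predict(input_fn=eval_input_fn)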
official/recommendation/dataset_test.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for dataset.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
+from official.recommendation import dataset
+
+_TRAIN_FNAME = os.path.join(
+    os.path.dirname(__file__), "unittest_data/test_train_ratings.csv")
+_TEST_FNAME = os.path.join(
+    os.path.dirname(__file__), "unittest_data/test_eval_ratings.csv")
+_TEST_NEG_FNAME = os.path.join(
+    os.path.dirname(__file__), "unittest_data/test_eval_negative.csv")
+_NUM_NEG = 4
+
+
+class DatasetTest(tf.test.TestCase):
+
+  def test_load_data(self):
+    data = dataset.load_data(_TEST_FNAME)
+    self.assertEqual(len(data), 2)
+
+    self.assertEqual(data[0][0], 0)
+    self.assertEqual(data[0][2], 1)
+
+    self.assertEqual(data[-1][0], 1)
+    self.assertEqual(data[-1][2], 1)
+
+  def test_data_preprocessing(self):
+    ncf_dataset = dataset.data_preprocessing(
+        _TRAIN_FNAME, _TEST_FNAME, _TEST_NEG_FNAME, _NUM_NEG)
+
+    # Check train data preprocessing
+    self.assertAllEqual(np.array(ncf_dataset.train_data)[:, 2],
+                        np.full(len(ncf_dataset.train_data), 1))
+    self.assertEqual(ncf_dataset.num_users, 2)
+    self.assertEqual(ncf_dataset.num_items, 175)
+
+    # Check test dataset
+    test_dataset = ncf_dataset.all_eval_data
+    first_true_item = test_dataset[100]
+    self.assertEqual(first_true_item[1], ncf_dataset.eval_true_items[0])
+    self.assertEqual(first_true_item[1], ncf_dataset.eval_all_items[0][-1])
+
+    last_gt_item = test_dataset[-1]
+    self.assertEqual(last_gt_item[1], ncf_dataset.eval_true_items[-1])
+    self.assertEqual(last_gt_item[1], ncf_dataset.eval_all_items[-1][-1])
+
+    test_list = test_dataset.tolist()
+
+    first_test_items = [x[1] for x in test_list if x[0] == 0]
+    self.assertAllEqual(first_test_items, ncf_dataset.eval_all_items[0])
+
+    last_test_items = [x[1] for x in test_list if x[0] == 1]
+    self.assertAllEqual(last_test_items, ncf_dataset.eval_all_items[-1])
+
+  def test_generate_train_dataset(self):
+    # Check train dataset
+    ncf_dataset = dataset.data_preprocessing(
+        _TRAIN_FNAME, _TEST_FNAME, _TEST_NEG_FNAME, _NUM_NEG)
+
+    train_dataset = dataset.generate_train_dataset(
+        ncf_dataset.train_data, ncf_dataset.num_items, _NUM_NEG)
+
+    # Each user has 1 positive instance followed by _NUM_NEG negative instances
+    train_data_0 = train_dataset[0]
+    self.assertEqual(train_data_0[2], 1)
+    for i in range(1, _NUM_NEG + 1):
+      train_data = train_dataset[i]
+      self.assertEqual(train_data_0[0], train_data[0])
+      self.assertNotEqual(train_data_0[1], train_data[1])
+      self.assertEqual(0, train_data[2])
+
+    train_data_last = train_dataset[-1 - _NUM_NEG]
+    self.assertEqual(train_data_last[2], 1)
+    for i in range(-1, -_NUM_NEG - 1, -1):
+      train_data = train_dataset[i]
+      self.assertEqual(train_data_last[0], train_data[0])
+      self.assertNotEqual(train_data_last[1], train_data[1])
+      self.assertEqual(0, train_data[2])
+
+
+if __name__ == "__main__":
+  tf.test.main()
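
Since the module ends with tf.test.main(), the suite runs as a plain script, assuming the repository root is on PYTHONPATH so the official package resolves:

  python official/recommendation/dataset_test.py

The unittest_data/ fixture CSVs referenced at the top of the file are what pin the expected values asserted here (2 users, 175 items).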
