3535import tensorflow as tf
3636
3737FLAGS = flags .FLAGS
38- flags .DEFINE_string ("data_dir" , "/tmp" ,
39- "Path to download and store movielens data." )
40- flags .DEFINE_string ("output_dir" , None ,
41- "Path to the directory of output files." )
42- flags .DEFINE_bool ("build_movie_vocab" , True ,
43- "If yes, generate sorted movie vocab." )
44- flags .DEFINE_integer ("min_timeline_length" , 3 ,
45- "The minimum timeline length to construct examples." )
46- flags .DEFINE_integer ("max_context_length" , 10 ,
47- "The maximun length of user context history." )
48-
# Permalinks to download movielens data.
MOVIELENS_1M_URL = "http://files.grouplens.org/datasets/movielens/ml-1m.zip"
MOVIELENS_ZIP_FILENAME = "ml-1m.zip"
# SHA-256 checksum of the zip above, used to verify the downloaded archive.
MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"
# Directory name produced when the zip is extracted, and the data files inside it.
MOVIELENS_EXTRACTED_DIR = "ml-1m"
RATINGS_FILE_NAME = "ratings.dat"
MOVIES_FILE_NAME = "movies.dat"
# Movie id reserved for out-of-vocabulary (unknown) movies.
OOV_MOVIE_ID = 0
6151
6252
def define_flags():
  """Registers the command-line flags consumed by this script.

  Called from the __main__ guard before app.run() so that absl parses
  these flags from argv.
  """
  flags.DEFINE_string("data_dir", "/tmp",
                      "Path to download and store movielens data.")
  flags.DEFINE_string("output_dir", None,
                      "Path to the directory of output files.")
  flags.DEFINE_bool("build_movie_vocab", True,
                    "If yes, generate sorted movie vocab.")
  flags.DEFINE_integer("min_timeline_length", 3,
                       "The minimum timeline length to construct examples.")
  # Help text fixed: "maximun" -> "maximum".
  flags.DEFINE_integer("max_context_length", 10,
                       "The maximum length of user context history.")
64+
65+
6366def download_and_extract_data (data_directory , url = MOVIELENS_1M_URL ):
6467 """Download and extract zip containing MovieLens data to a given directory.
6568
@@ -74,6 +77,8 @@ def download_and_extract_data(data_directory, url=MOVIELENS_1M_URL):
7477 path_to_zip = tf .keras .utils .get_file (
7578 fname = MOVIELENS_ZIP_FILENAME ,
7679 origin = url ,
80+ file_hash = MOVIELENS_ZIP_HASH ,
81+ hash_algorithm = "sha256" ,
7782 extract = True ,
7883 cache_dir = data_directory )
7984 extracted_file_dir = os .path .join (
@@ -154,10 +159,13 @@ def generate_examples_from_timelines(timelines,
154159
155160
def write_tfrecords(tf_examples, filename):
  """Writes the given serialized tf examples to `filename`.

  Args:
    tf_examples: iterable of serialized tf.Example bytes.
    filename: destination TFRecord file path.

  Returns:
    The number of examples written.
  """
  written = 0
  with tf.io.TFRecordWriter(filename) as record_writer:
    for serialized_example in tf_examples:
      record_writer.write(serialized_example)
      written += 1
  return written
161169
162170
163171def generate_sorted_movie_vocab (movies_df , movie_counts ):
@@ -176,8 +184,9 @@ def write_vocab_json(vocab_movies, filename):
176184 json .dump (vocab_movies , jsonfile , indent = 2 )
177185
178186
179- def main (_ ):
180- data_dir = FLAGS .data_dir
187+ def generate_datasets (data_dir , output_dir , min_timeline_length ,
188+ max_context_length , build_movie_vocab ):
189+ """Generates train and test datasets as TFRecord, and returns stats."""
181190 if not tf .io .gfile .exists (data_dir ):
182191 tf .io .gfile .makedirs (data_dir )
183192
@@ -186,24 +195,37 @@ def main(_):
186195 timelines , movie_counts = convert_to_timelines (ratings_df )
187196 train_examples , test_examples = generate_examples_from_timelines (
188197 timelines = timelines ,
189- min_timeline_len = FLAGS .min_timeline_length ,
190- max_context_len = FLAGS .max_context_length )
191-
192- if not tf .io .gfile .exists (FLAGS .output_dir ):
193- tf .io .gfile .makedirs (FLAGS .output_dir )
194- write_tfrecords (
195- tf_examples = train_examples ,
196- filename = os .path .join (FLAGS .output_dir , OUTPUT_TRAINING_DATA_FILENAME ))
197- write_tfrecords (
198- tf_examples = test_examples ,
199- filename = os .path .join (FLAGS .output_dir , OUTPUT_TESTING_DATA_FILENAME ))
200- if FLAGS .build_movie_vocab :
198+ min_timeline_len = min_timeline_length ,
199+ max_context_len = max_context_length )
200+
201+ if not tf .io .gfile .exists (output_dir ):
202+ tf .io .gfile .makedirs (output_dir )
203+ train_file = os .path .join (output_dir , OUTPUT_TRAINING_DATA_FILENAME )
204+ train_size = write_tfrecords (tf_examples = train_examples , filename = train_file )
205+ test_file = os .path .join (output_dir , OUTPUT_TESTING_DATA_FILENAME )
206+ test_size = write_tfrecords (tf_examples = test_examples , filename = test_file )
207+ stats = {
208+ "train_size" : train_size ,
209+ "test_size" : test_size ,
210+ "train_file" : train_file ,
211+ "test_file" : test_file ,
212+ }
213+ if build_movie_vocab :
201214 vocab_movies = generate_sorted_movie_vocab (
202215 movies_df = movies_df , movie_counts = movie_counts )
203- write_vocab_json (
204- vocab_movies = vocab_movies ,
205- filename = os .path .join (FLAGS .output_dir , OUTPUT_MOVIE_VOCAB_FILENAME ))
216+ vocab_file = os .path .join (output_dir , OUTPUT_MOVIE_VOCAB_FILENAME )
217+ write_vocab_json (vocab_movies = vocab_movies , filename = vocab_file )
218+ stats .update (vocab_size = len (vocab_movies ), vocab_file = vocab_file )
219+ return stats
220+
221+
def main(_):
  """Entry point: generates the datasets from FLAGS and logs the stats."""
  stats = generate_datasets(
      data_dir=FLAGS.data_dir,
      output_dir=FLAGS.output_dir,
      min_timeline_length=FLAGS.min_timeline_length,
      max_context_length=FLAGS.max_context_length,
      build_movie_vocab=FLAGS.build_movie_vocab,
  )
  tf.compat.v1.logging.info("Generated dataset: %s", stats)
206227
207228
if __name__ == "__main__":
  # Flags must be registered before app.run() parses argv.
  define_flags()
  app.run(main)