@@ -31,7 +31,7 @@ def run_models(dataset_name: str,
31
31
model_name : str ,
32
32
gcs_bucket : str ,
33
33
train_split : str ,
34
- validation_split : Optional [ str ] = None ,
34
+ validation_split : str ,
35
35
one_hot : Optional [bool ] = True ,
36
36
epochs : Optional [int ] = 100 ,
37
37
batch_size : Optional [int ] = 128 ,
@@ -76,7 +76,7 @@ def run_models(dataset_name: str,
76
76
during callbacks.
77
77
4. 'model_checkpoint': the path to the model checkpoints registered
78
78
during callbacks.
79
- 5. 'save_model ': the path to the saved model.
79
+ 5. 'saved_model ': the path to the saved model.
80
80
"""
81
81
model_dirs = get_model_dirs (gcs_bucket , job_name )
82
82
@@ -101,7 +101,7 @@ def get_model_dirs(gcs_bucket, job_name):
101
101
return {'tensorboard_logs' : os .path .join (gcs_base_path , 'logs' ),
102
102
'model_checkpoint' :
103
103
os .path .join (gcs_base_path , 'checkpoints' ),
104
- 'save_model ' : os .path .join (gcs_base_path , 'saved_model' )}
104
+ 'saved_model ' : os .path .join (gcs_base_path , 'saved_model' )}
105
105
106
106
107
107
# TODO(uribejuan): Write function to make sure the input is valid
@@ -120,7 +120,7 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
120
120
if model_name == 'resnet' :
121
121
image_size = 224
122
122
width_ratio = 1
123
- else : # Assumes model_name is efficientnet version
123
+ else : # Assumes model_name is an efficientnet version
124
124
image_size = model .config .resolution
125
125
width_ratio = model .config .width_coefficient
126
126
@@ -132,7 +132,7 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
132
132
callbacks = [
133
133
tf .keras .callbacks .TensorBoard (log_dir = model_dirs ['tensorboard_logs' ]),
134
134
tf .keras .callbacks .ModelCheckpoint (
135
- model_dirs ['model_checkpoint_dir ' ], save_best_only = True ),
135
+ model_dirs ['model_checkpoint ' ], save_best_only = True ),
136
136
tf .keras .callbacks .EarlyStopping (
137
137
monitor = 'loss' , min_delta = 0.001 , patience = 3 )
138
138
]
@@ -145,12 +145,11 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
145
145
146
146
model .fit (
147
147
train_ds ,
148
- validation_split = 0.2 , # TODO(uribejuan): users can modify this split
149
148
validation_data = validation_ds ,
150
149
epochs = epochs ,
151
150
callbacks = callbacks )
152
151
153
- model .save (model_dirs ['save_model_dir ' ])
152
+ model .save (model_dirs ['saved_model ' ])
154
153
155
154
156
155
def load_data_from_builder (builder , train_split , validation_split , image_size ,
@@ -168,7 +167,7 @@ def load_data_from_builder(builder, train_split, validation_split, image_size,
168
167
if validation_split is not None :
169
168
validation_ds = builder .as_dataset (
170
169
validation_split , shuffle_files = True , as_supervised = True )
171
- validation_ds = data_pipeline (validation_split , image_size , width_ratio ,
170
+ validation_ds = data_pipeline (validation_ds , image_size , width_ratio ,
172
171
batch_size , num_classes , one_hot ,
173
172
num_examples )
174
173
else :
0 commit comments