Skip to content

Commit d91eea9

Browse files
juanuribe28Tensorflow Cloud maintainers
authored andcommitted
Fix run_models errors
1. Pass validation_ds instead of validation_split to data_pipeline to solve AttributeError. 2. Remove validation_split arg from model.fit. 3. Update model_dirs keys. PiperOrigin-RevId: 384621093
1 parent f19e308 commit d91eea9

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/python/tensorflow_cloud/core/experimental/models.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run_models(dataset_name: str,
3131
model_name: str,
3232
gcs_bucket: str,
3333
train_split: str,
34-
validation_split: Optional[str] = None,
34+
validation_split: str,
3535
one_hot: Optional[bool] = True,
3636
epochs: Optional[int] = 100,
3737
batch_size: Optional[int] = 128,
@@ -76,7 +76,7 @@ def run_models(dataset_name: str,
7676
during callbacks.
7777
4. 'model_checkpoint': the path to the model checkpoints registered
7878
during callbacks.
79-
5. 'save_model': the path to the saved model.
79+
5. 'saved_model': the path to the saved model.
8080
"""
8181
model_dirs = get_model_dirs(gcs_bucket, job_name)
8282

@@ -101,7 +101,7 @@ def get_model_dirs(gcs_bucket, job_name):
101101
return {'tensorboard_logs': os.path.join(gcs_base_path, 'logs'),
102102
'model_checkpoint':
103103
os.path.join(gcs_base_path, 'checkpoints'),
104-
'save_model': os.path.join(gcs_base_path, 'saved_model')}
104+
'saved_model': os.path.join(gcs_base_path, 'saved_model')}
105105

106106

107107
# TODO(uribejuan): Write function to make sure the input is valid
@@ -120,7 +120,7 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
120120
if model_name == 'resnet':
121121
image_size = 224
122122
width_ratio = 1
123-
else: # Assumes model_name is efficientnet version
123+
else: # Assumes model_name is an efficientnet version
124124
image_size = model.config.resolution
125125
width_ratio = model.config.width_coefficient
126126

@@ -132,7 +132,7 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
132132
callbacks = [
133133
tf.keras.callbacks.TensorBoard(log_dir=model_dirs['tensorboard_logs']),
134134
tf.keras.callbacks.ModelCheckpoint(
135-
model_dirs['model_checkpoint_dir'], save_best_only=True),
135+
model_dirs['model_checkpoint'], save_best_only=True),
136136
tf.keras.callbacks.EarlyStopping(
137137
monitor='loss', min_delta=0.001, patience=3)
138138
]
@@ -145,12 +145,11 @@ def classifier_trainer(dataset_name, model_name, batch_size, epochs,
145145

146146
model.fit(
147147
train_ds,
148-
validation_split=0.2, # TODO(uribejuan): users can modify this split
149148
validation_data=validation_ds,
150149
epochs=epochs,
151150
callbacks=callbacks)
152151

153-
model.save(model_dirs['save_model_dir'])
152+
model.save(model_dirs['saved_model'])
154153

155154

156155
def load_data_from_builder(builder, train_split, validation_split, image_size,
@@ -168,7 +167,7 @@ def load_data_from_builder(builder, train_split, validation_split, image_size,
168167
if validation_split is not None:
169168
validation_ds = builder.as_dataset(
170169
validation_split, shuffle_files=True, as_supervised=True)
171-
validation_ds = data_pipeline(validation_split, image_size, width_ratio,
170+
validation_ds = data_pipeline(validation_ds, image_size, width_ratio,
172171
batch_size, num_classes, one_hot,
173172
num_examples)
174173
else:

src/python/tensorflow_cloud/core/experimental/tests/unit/models_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,19 +143,20 @@ def test_run_models_locally(self):
143143
'requirements_txt': 'requirements_txt',
144144
'worker_count': 5,}
145145
result = models.run_models('dataset_name', 'model_name', 'gcs_bucket',
146-
'train', **run_kwargs)
146+
'train_split', 'validation_split',
147+
**run_kwargs)
147148
self.remote.assert_called()
148149
self.run.assert_called_with(**run_kwargs)
149150
self.classifier_trainer.assert_not_called()
150151
return_keys = ['job_id', 'docker_image', 'tensorboard_logs',
151-
'model_checkpoint', 'save_model']
152+
'model_checkpoint', 'saved_model']
152153
self.assertListEqual(list(result.keys()), return_keys)
153154

154155
def test_run_models_remote(self):
155156
self.setup_run()
156157
self.setup_run_models()
157158
result = models.run_models('dataset_name', 'model_name', 'gcs_bucket',
158-
'train')
159+
'train_split', 'validation_split')
159160
self.remote.assert_called()
160161
self.run.assert_not_called()
161162
self.classifier_trainer.assert_called()

0 commit comments

Comments
 (0)