Skip to content

Commit 143fd0b

Browse files
lehougoogletensorflower-gardener
authored andcommitted
Minor bug fixes
PiperOrigin-RevId: 422637653
1 parent 871c4e0 commit 143fd0b

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

official/core/base_task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ def initialize(self, model: tf.keras.Model):
101101
ckpt_dir_or_file = self.task_config.init_checkpoint
102102
logging.info("Trying to load pretrained checkpoint from %s",
103103
ckpt_dir_or_file)
104-
if tf.io.gfile.isdir(ckpt_dir_or_file):
104+
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
105105
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
106106
if not ckpt_dir_or_file:
107+
logging.info("No checkpoint file found from %s. Will not load.",
108+
ckpt_dir_or_file)
107109
return
108110

109111
if hasattr(model, "checkpoint_items"):

official/nlp/tasks/dual_encoder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,13 @@ def validation_step(self,
187187
def initialize(self, model):
188188
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
189189
ckpt_dir_or_file = self.task_config.init_checkpoint
190-
if tf.io.gfile.isdir(ckpt_dir_or_file):
190+
logging.info('Trying to load pretrained checkpoint from %s',
191+
ckpt_dir_or_file)
192+
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
191193
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
192194
if not ckpt_dir_or_file:
195+
logging.info('No checkpoint file found from %s. Will not load.',
196+
ckpt_dir_or_file)
193197
return
194198

195199
pretrain2finetune_mapping = {

official/nlp/tasks/sentence_prediction.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,14 @@ def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
223223
def initialize(self, model):
224224
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
225225
ckpt_dir_or_file = self.task_config.init_checkpoint
226+
logging.info('Trying to load pretrained checkpoint from %s',
227+
ckpt_dir_or_file)
228+
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
229+
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
226230
if not ckpt_dir_or_file:
231+
logging.info('No checkpoint file found from %s. Will not load.',
232+
ckpt_dir_or_file)
227233
return
228-
if tf.io.gfile.isdir(ckpt_dir_or_file):
229-
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
230234

231235
pretrain2finetune_mapping = {
232236
'encoder': model.checkpoint_items['encoder'],

0 commit comments

Comments
 (0)