Skip to content

Commit afbf309

Browse files
committed
fix github issue
1 parent 962dfb9 commit afbf309

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/training/wiki_filtered_classifier_tuning/multitask_alert_filtered_wiki_classifier_training.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,7 @@ def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_
343343
torch.save(state, latest_checkpoint_path)
344344
logger.info(f"Checkpoint saved for step {global_step} at {latest_checkpoint_path}")
345345

346-
def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
347-
checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
346+
def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
348347
train_dataset = MultitaskDataset(train_samples, self.tokenizer)
349348
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
350349
val_dataset = MultitaskDataset(val_samples, self.tokenizer)

0 commit comments

Comments
 (0)