
Commit 962dfb9

fix github issue
1 parent 7b3db51 commit 962dfb9

File tree

1 file changed: +34 -34 lines changed

src/training/wiki_filtered_classifier_tuning/multitask_alert_filtered_wiki_classifier_training.py

Lines changed: 34 additions & 34 deletions
@@ -1,7 +1,7 @@
 import json
 import logging
 import os
-from collections import defaultdic
+from collections import defaultdict
 from pathlib import Path
 import random
 import argparse
@@ -14,9 +14,9 @@
 import requests
 import torch
 import torch.nn as nn
-from datasets import load_datase
-from sklearn.model_selection import train_test_spli
-from torch.utils.data import DataLoader, Datase
+from datasets import load_dataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader, Dataset
 from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
 import wikipediaapi
 
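The misspelled imports fixed in this hunk (load_datase, train_test_spli, Datase), like the defaultdic fix above, would each have stopped the script at import time; most of the commit's other +34/-34 pairs are whitespace-only cleanups. A minimal, self-contained reproduction of the failure mode, using only the standard library:

# Python resolves from-imports at module load, so a misspelled name raises
# ImportError before any training code runs.
try:
    from collections import defaultdic  # pre-fix typo for 'defaultdict'
except ImportError as e:
    print(f"caught: {e}")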
@@ -37,7 +37,7 @@
 
 
 WIKI_DATA_DIR = "wikipedia_data"
-ARTICLES_PER_CATEGORY = 10000
+ARTICLES_PER_CATEGORY = 10000
 PAGEVIEWS_API_USER_AGENT = "MyMultitaskProject/1.0 ([email protected]) requests/2.0"
 
 
@@ -90,7 +90,7 @@ def __init__(self, model, tokenizer, task_configs, device="cuda"):
         self.device = device
         self.jailbreak_label_mapping = None
         self.loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs if task_configs}
-
+
         self.EXCLUDED_CATEGORIES = {
             "fiction", "culture", "arts", "comics", "media", "entertainment",
             "people", "religion", "sports", "society", "geography", "history"
@@ -104,16 +104,16 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
         if visited_categories is None:
             visited_categories = set()
 
-        if depth >= max_depth or len(visited_categories) > max_articles * 2:
+        if depth >= max_depth or len(visited_categories) > max_articles * 2:
             return []
-
+
         articles = []
         visited_categories.add(category_page.title)
-
+
         members = list(category_page.categorymembers.values())
         for member in members:
             if len(articles) >= max_articles: break
-
+
             if member.ns == wikipediaapi.Namespace.CATEGORY and member.title not in visited_categories:
                 member_title_lower = member.title.lower()
                 if not any(excluded in member_title_lower for excluded in self.EXCLUDED_CATEGORIES):
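A detail worth noting in this hunk: the exclusion test is a case-insensitive substring match against the full category title, so a term like "history" also filters subcategories such as "History of computing". A small standalone illustration with made-up category names:

# Substring-based category exclusion, mirroring the check in
# _get_wiki_articles_recursive above.
EXCLUDED_CATEGORIES = {"fiction", "sports", "history"}
member_title_lower = "Category:History of computing".lower()
print(any(excluded in member_title_lower for excluded in EXCLUDED_CATEGORIES))  # True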
@@ -126,7 +126,7 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
                     ))
             elif member.ns == wikipediaapi.Namespace.MAIN:
                 articles.append(member.title)
-
+
         return list(set(articles))
 
     def _get_pageviews(self, article_title, session, start_date, end_date):
@@ -153,7 +153,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
         logger.info("--- Starting Wikipedia Data Download and Preparation ---")
         os.makedirs(data_dir, exist_ok=True)
         wiki_wiki = wikipediaapi.Wikipedia('MyMultitaskProject ([email protected])', 'en')
-
+
         end_date = datetime.utcnow()
         start_date = end_date - timedelta(days=90)
         end_date_str = end_date.strftime('%Y%m%d')
@@ -172,9 +172,9 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
 
             num_candidates = articles_per_category * candidate_multiplier
             logger.info(f" Recursively searching for up to {num_candidates} candidate articles (max depth=4)...")
-
+
             candidate_titles = self._get_wiki_articles_recursive(cat_page, num_candidates)
-
+
             logger.info(f" Found {len(candidate_titles)} unique candidate articles.")
 
             if not candidate_titles:
@@ -202,7 +202,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
             for title in tqdm(article_titles, desc=f" Downloading '{category}'"):
                 safe_filename = "".join([c for c in title if c.isalpha() or c.isdigit() or c==' ']).rstrip() + ".txt"
                 file_path = os.path.join(category_path, safe_filename)
-
+
                 if os.path.exists(file_path): continue
 
                 page = wiki_wiki.page(title)
@@ -213,7 +213,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
                         saved_count += 1
                 except Exception as e:
                     logger.error(f" Could not save article '{title}': {e}")
-
+
             logger.info(f" Saved {saved_count} new articles for '{category}'.")
 
         logger.info("\n--- Wikipedia Data Download Complete ---")
@@ -303,7 +303,7 @@ def _load_jailbreak_dataset(self):
             # Using the 'en' (English) configuration
             jb_dataset2 = load_dataset("Babelscape/ALERT", "alert")
             texts2 = jb_dataset2['test']['prompt']
-            # Prefix labels to avoid conflicts with the other datase
+            # Prefix labels to avoid conflicts with the other dataset
             labels2 = ["ALERT:" + str(l) for l in jb_dataset2['test']['category']]
             all_texts.extend(texts2)
             all_labels.extend(labels2)
@@ -313,21 +313,21 @@ def _load_jailbreak_dataset(self):
             logger.info(f"Total loaded jailbreak samples: {len(all_texts)}")
             unique_labels = sorted(list(set(all_labels)))
             label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
-
+
             logger.info(f"Jailbreak label distribution: {dict(sorted([(label, all_labels.count(label)) for label in set(all_labels)], key=lambda x: x[1], reverse=True))}")
             logger.info(f"Total unique jailbreak labels: {len(unique_labels)}")
-
+
             label_indices = [label_to_idx[label] for label in all_labels]
             samples = list(zip(all_texts, label_indices))
-
+
             self.jailbreak_label_mapping = {"label_to_idx": label_to_idx, "idx_to_label": {idx: label for label, idx in label_to_idx.items()}}
             return samples
         except Exception as e:
             import traceback
             logger.error(f"Failed to load jailbreak datasets: {e}")
             traceback.print_exc()
             return []
-
+
     def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
         os.makedirs(checkpoint_dir, exist_ok=True)
         latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
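The mapping built in this hunk sorts the unique labels before assigning indices, which makes the label-to-index assignment deterministic across runs. A toy sketch of the same logic with invented labels:

# Toy version of the jailbreak label mapping: sorted() fixes the label
# order, so the indices are reproducible run to run.
all_labels = ["ALERT:hate", "benign", "ALERT:weapons", "benign"]
unique_labels = sorted(set(all_labels))
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
idx_to_label = {idx: label for label, idx in label_to_idx.items()}
label_indices = [label_to_idx[label] for label in all_labels]
print(label_to_idx)     # {'ALERT:hate': 0, 'ALERT:weapons': 1, 'benign': 2}
print(label_indices)    # [0, 2, 1, 2]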
@@ -343,7 +343,7 @@ def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_
         torch.save(state, latest_checkpoint_path)
         logger.info(f"Checkpoint saved for step {global_step} at {latest_checkpoint_path}")
 
-    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
+    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
              checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
         train_dataset = MultitaskDataset(train_samples, self.tokenizer)
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
@@ -372,7 +372,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
             pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
             for step, batch in pbar:
                 if steps_to_skip > 0 and step < steps_to_skip: continue
-
+
                 optimizer.zero_grad()
                 outputs = self.model(
                     input_ids=batch["input_ids"].to(self.device),
@@ -383,13 +383,13 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                         task_logits = outputs[task_name][i : i + 1]
                         task_label = batch["label"][i : i + 1].to(self.device)
                         task_weight = self.task_configs[task_name].get("weight", 1.0)
-                        batch_loss += self.loss_fns[task_name](task_logits, task_label) * task_weigh
+                        batch_loss += self.loss_fns[task_name](task_logits, task_label) * task_weight
                 batch_loss.backward()
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                 optimizer.step(); scheduler.step(); global_step += 1
                 if global_step > 0 and global_step % save_steps == 0:
                     self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
-
+
             if val_loader: self.evaluate(val_loader)
             self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
             steps_to_skip = 0
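The substantive fix in this hunk is task_weigh -> task_weight, which would otherwise raise a NameError on the first training step. For context, a runnable sketch of the weighted multitask loss the loop computes; the task names, class counts, and weights below are illustrative stand-ins, not values from the repository:

import torch
import torch.nn as nn

# Each task's CrossEntropyLoss is scaled by its configured weight and summed
# into one batch loss, as in train() above.
task_configs = {"pii": {"num_classes": 4, "weight": 3.0},
                "jailbreak": {"num_classes": 7, "weight": 2.0}}
loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs}

# Random logits stand in for model outputs; one sample per task.
task_logits = {task: torch.randn(1, cfg["num_classes"])
               for task, cfg in task_configs.items()}
task_labels = {"pii": torch.tensor([2]), "jailbreak": torch.tensor([5])}

batch_loss = torch.tensor(0.0)
for task_name, cfg in task_configs.items():
    task_weight = cfg.get("weight", 1.0)
    batch_loss += loss_fns[task_name](task_logits[task_name], task_labels[task_name]) * task_weight
print(batch_loss.item())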
@@ -443,13 +443,13 @@ def main():
     logger.info("--- Starting Model Training ---")
 
     task_configs, label_mappings, checkpoint_to_load = {}, {}, None
-
+
     if args.resume:
         latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
         if os.path.exists(latest_checkpoint_path):
            logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
            checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)
-
+
            task_configs = checkpoint_to_load.get('task_configs')
            label_mappings = checkpoint_to_load.get('label_mappings')
 
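For readers following the resume path: the checkpoint is expected to carry task_configs and label_mappings so that the rebuilt model heads match the saved weights. A self-contained sketch of that pattern, with the directory name and checkpoint keys assumed from the surrounding diff:

import os
import torch

checkpoint_dir = "checkpoints"  # hypothetical location
latest_checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pt")

task_configs, label_mappings, checkpoint_to_load = {}, {}, None
if os.path.exists(latest_checkpoint_path):
    # map_location keeps the load device-agnostic, as in main() above.
    checkpoint_to_load = torch.load(latest_checkpoint_path, map_location="cpu")
    task_configs = checkpoint_to_load.get("task_configs") or {}
    label_mappings = checkpoint_to_load.get("label_mappings") or {}
print("resuming" if checkpoint_to_load else "starting fresh")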
@@ -458,10 +458,10 @@ def main():
                logger.warning("Cannot safely resume. Starting a fresh training run.")
                args.resume = False
                checkpoint_to_load = None
-                task_configs = {}
+                task_configs = {}
            else:
                logger.info("Loaded model configuration from checkpoint.")
-
+
        else:
            logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
            args.resume = False
@@ -476,12 +476,12 @@ def main():
         task_configs["pii"] = {"num_classes": len(label_mappings["pii"]["label_mapping"]["label_to_idx"]), "weight": 3.0}
     if "jailbreak" in label_mappings:
         task_configs["jailbreak"] = {"num_classes": len(label_mappings["jailbreak"]["label_mapping"]["label_to_idx"]), "weight": 2.0}
-
+
     if not task_configs:
         logger.error("No tasks configured. Exiting."); return
 
     logger.info(f"Final task configurations: {task_configs}")
-
+
     model = MultitaskBertModel(base_model_name, task_configs)
 
     if args.resume and checkpoint_to_load:
@@ -497,9 +497,9 @@ def main():
     active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings
 
     trainer = MultitaskTrainer(model, tokenizer, task_configs, device)
-
+
     logger.info(f"Total training samples: {len(train_samples)}")
-
+
     trainer.train(
         train_samples, val_samples, active_label_mappings,
         num_epochs=10, batch_size=16,
@@ -508,7 +508,7 @@ def main():
         save_steps=args.save_steps,
         checkpoint_to_load=checkpoint_to_load
     )
-
+
     trainer.save_model(output_path)
     with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
         json.dump(active_label_mappings, f, indent=2)
