Commit 063941d

fix github issue
1 parent 4134924 commit 063941d

1 file changed: +29 -30

src/training/wiki_filtered_classifier_tuning/multitask_alert_filtered_wiki_classifier_training.py

Lines changed: 29 additions & 30 deletions
@@ -1,14 +1,13 @@
 import json
 import logging
 import os
-from collections import defaultdict
-from pathlib import Path
 import random
 import argparse
-from tqdm import tqdm
 import time
-from datetime import datetime, timedelta
 import urllib.parse
+from collections import defaultdict
+from pathlib import Path
+from datetime import datetime, timedelta
 
 import numpy as np
 import requests
@@ -37,7 +36,7 @@
 
 
 WIKI_DATA_DIR = "wikipedia_data"
-ARTICLES_PER_CATEGORY = 10000
+ARTICLES_PER_CATEGORY = 10000
 PAGEVIEWS_API_USER_AGENT = "MyMultitaskProject/1.0 ([email protected]) requests/2.0"
 
 
@@ -90,7 +89,7 @@ def __init__(self, model, tokenizer, task_configs, device="cuda"):
         self.device = device
         self.jailbreak_label_mapping = None
         self.loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs if task_configs}
-
+
         self.EXCLUDED_CATEGORIES = {
             "fiction", "culture", "arts", "comics", "media", "entertainment",
             "people", "religion", "sports", "society", "geography", "history"
@@ -104,16 +103,16 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
         if visited_categories is None:
             visited_categories = set()
 
-        if depth >= max_depth or len(visited_categories) > max_articles * 2:
+        if depth >= max_depth or len(visited_categories) > max_articles * 2:
             return []
-
+
         articles = []
         visited_categories.add(category_page.title)
-
+
         members = list(category_page.categorymembers.values())
         for member in members:
             if len(articles) >= max_articles: break
-
+
             if member.ns == wikipediaapi.Namespace.CATEGORY and member.title not in visited_categories:
                 member_title_lower = member.title.lower()
                 if not any(excluded in member_title_lower for excluded in self.EXCLUDED_CATEGORIES):
@@ -126,7 +125,7 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
                     ))
             elif member.ns == wikipediaapi.Namespace.MAIN:
                 articles.append(member.title)
-
+
         return list(set(articles))
 
     def _get_pageviews(self, article_title, session, start_date, end_date):
@@ -153,7 +152,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
         logger.info("--- Starting Wikipedia Data Download and Preparation ---")
         os.makedirs(data_dir, exist_ok=True)
         wiki_wiki = wikipediaapi.Wikipedia('MyMultitaskProject ([email protected])', 'en')
-
+
         end_date = datetime.utcnow()
         start_date = end_date - timedelta(days=90)
         end_date_str = end_date.strftime('%Y%m%d')
@@ -172,9 +171,9 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
 
             num_candidates = articles_per_category * candidate_multiplier
             logger.info(f" Recursively searching for up to {num_candidates} candidate articles (max depth=4)...")
-
+
             candidate_titles = self._get_wiki_articles_recursive(cat_page, num_candidates)
-
+
             logger.info(f" Found {len(candidate_titles)} unique candidate articles.")
 
             if not candidate_titles:
@@ -202,7 +201,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
             for title in tqdm(article_titles, desc=f" Downloading '{category}'"):
                 safe_filename = "".join([c for c in title if c.isalpha() or c.isdigit() or c==' ']).rstrip() + ".txt"
                 file_path = os.path.join(category_path, safe_filename)
-
+
                 if os.path.exists(file_path): continue
 
                 page = wiki_wiki.page(title)
@@ -213,7 +212,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
                         saved_count += 1
                 except Exception as e:
                     logger.error(f" Could not save article '{title}': {e}")
-
+
             logger.info(f" Saved {saved_count} new articles for '{category}'.")
 
         logger.info("\n--- Wikipedia Data Download Complete ---")
@@ -313,21 +312,21 @@ def _load_jailbreak_dataset(self):
             logger.info(f"Total loaded jailbreak samples: {len(all_texts)}")
             unique_labels = sorted(list(set(all_labels)))
             label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
-
+
             logger.info(f"Jailbreak label distribution: {dict(sorted([(label, all_labels.count(label)) for label in set(all_labels)], key=lambda x: x[1], reverse=True))}")
             logger.info(f"Total unique jailbreak labels: {len(unique_labels)}")
-
+
             label_indices = [label_to_idx[label] for label in all_labels]
             samples = list(zip(all_texts, label_indices))
-
+
             self.jailbreak_label_mapping = {"label_to_idx": label_to_idx, "idx_to_label": {idx: label for label, idx in label_to_idx.items()}}
             return samples
         except Exception as e:
             import traceback
             logger.error(f"Failed to load jailbreak datasets: {e}")
             traceback.print_exc()
             return []
-
+
     def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
         os.makedirs(checkpoint_dir, exist_ok=True)
         latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
@@ -371,7 +370,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
             pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
             for step, batch in pbar:
                 if steps_to_skip > 0 and step < steps_to_skip: continue
-
+
                 optimizer.zero_grad()
                 outputs = self.model(
                     input_ids=batch["input_ids"].to(self.device),
@@ -388,7 +387,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                 optimizer.step(); scheduler.step(); global_step += 1
                 if global_step > 0 and global_step % save_steps == 0:
                     self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
-
+
             if val_loader: self.evaluate(val_loader)
             self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
             steps_to_skip = 0
@@ -442,13 +441,13 @@ def main():
     logger.info("--- Starting Model Training ---")
 
     task_configs, label_mappings, checkpoint_to_load = {}, {}, None
-
+
     if args.resume:
         latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
         if os.path.exists(latest_checkpoint_path):
             logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
             checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)
-
+
             task_configs = checkpoint_to_load.get('task_configs')
             label_mappings = checkpoint_to_load.get('label_mappings')
 
@@ -460,7 +459,7 @@ def main():
                 task_configs = {}
             else:
                 logger.info("Loaded model configuration from checkpoint.")
-
+
         else:
             logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
             args.resume = False
@@ -475,12 +474,12 @@ def main():
         task_configs["pii"] = {"num_classes": len(label_mappings["pii"]["label_mapping"]["label_to_idx"]), "weight": 3.0}
     if "jailbreak" in label_mappings:
         task_configs["jailbreak"] = {"num_classes": len(label_mappings["jailbreak"]["label_mapping"]["label_to_idx"]), "weight": 2.0}
-
+
     if not task_configs:
         logger.error("No tasks configured. Exiting."); return
 
     logger.info(f"Final task configurations: {task_configs}")
-
+
     model = MultitaskBertModel(base_model_name, task_configs)
 
     if args.resume and checkpoint_to_load:
@@ -496,9 +495,9 @@ def main():
     active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings
 
     trainer = MultitaskTrainer(model, tokenizer, task_configs, device)
-
+
     logger.info(f"Total training samples: {len(train_samples)}")
-
+
     trainer.train(
         train_samples, val_samples, active_label_mappings,
         num_epochs=10, batch_size=16,
@@ -507,7 +506,7 @@ def main():
         save_steps=args.save_steps,
         checkpoint_to_load=checkpoint_to_load
     )
-
+
     trainer.save_model(output_path)
     with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
         json.dump(active_label_mappings, f, indent=2)
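
Note: read together, the hunks make two kinds of changes. The first hunk regroups the module header so the from-import statements (collections, pathlib, datetime) follow the plain import statements, and drops "from tqdm import tqdm" from the top of the file; every other hunk appears to be a whitespace-only cleanup, replacing lines that carried trailing whitespace with genuinely blank lines (hence the paired -/+ blank lines and the identical-looking ARTICLES_PER_CATEGORY pair). As a sketch, the import header after this commit, reconstructed from the first hunk alone (any imports past line 14 of the file are not shown in the diff):

import json
import logging
import os
import random
import argparse
import time
import urllib.parse
from collections import defaultdict
from pathlib import Path
from datetime import datetime, timedelta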
