@@ -1,14 +1,13 @@
 import json
 import logging
 import os
-from collections import defaultdict
-from pathlib import Path
 import random
 import argparse
-from tqdm import tqdm
 import time
-from datetime import datetime, timedelta
 import urllib.parse
+from collections import defaultdict
+from pathlib import Path
+from datetime import datetime, timedelta
 
 import numpy as np
 import requests
@@ -37,7 +36,7 @@
 
 
 WIKI_DATA_DIR = "wikipedia_data"
-ARTICLES_PER_CATEGORY = 10000
+ARTICLES_PER_CATEGORY = 10000
 PAGEVIEWS_API_USER_AGENT = "MyMultitaskProject/1.0 ([email protected]) requests/2.0"
 
 
@@ -90,7 +89,7 @@ def __init__(self, model, tokenizer, task_configs, device="cuda"):
         self.device = device
         self.jailbreak_label_mapping = None
         self.loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs if task_configs}
-        
+
         self.EXCLUDED_CATEGORIES = {
             "fiction", "culture", "arts", "comics", "media", "entertainment",
             "people", "religion", "sports", "society", "geography", "history"
@@ -104,16 +103,16 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
         if visited_categories is None:
             visited_categories = set()
 
-        if depth >= max_depth or len(visited_categories) > max_articles * 2:
+        if depth >= max_depth or len(visited_categories) > max_articles * 2:
             return []
-        
+
         articles = []
         visited_categories.add(category_page.title)
-        
+
         members = list(category_page.categorymembers.values())
         for member in members:
             if len(articles) >= max_articles: break
-            
+
             if member.ns == wikipediaapi.Namespace.CATEGORY and member.title not in visited_categories:
                 member_title_lower = member.title.lower()
                 if not any(excluded in member_title_lower for excluded in self.EXCLUDED_CATEGORIES):
@@ -126,7 +125,7 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
                     ))
             elif member.ns == wikipediaapi.Namespace.MAIN:
                 articles.append(member.title)
-        
+
         return list(set(articles))
 
     def _get_pageviews(self, article_title, session, start_date, end_date):
@@ -153,7 +152,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
         logger.info("--- Starting Wikipedia Data Download and Preparation ---")
         os.makedirs(data_dir, exist_ok=True)
         wiki_wiki = wikipediaapi.Wikipedia('MyMultitaskProject ([email protected])', 'en')
-        
+
         end_date = datetime.utcnow()
         start_date = end_date - timedelta(days=90)
         end_date_str = end_date.strftime('%Y%m%d')
@@ -172,9 +171,9 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
 
             num_candidates = articles_per_category * candidate_multiplier
             logger.info(f" Recursively searching for up to {num_candidates} candidate articles (max depth=4)...")
-            
+
             candidate_titles = self._get_wiki_articles_recursive(cat_page, num_candidates)
-            
+
             logger.info(f" Found {len(candidate_titles)} unique candidate articles.")
 
             if not candidate_titles:
@@ -202,7 +201,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
             for title in tqdm(article_titles, desc=f" Downloading '{category}'"):
                 safe_filename = "".join([c for c in title if c.isalpha() or c.isdigit() or c == ' ']).rstrip() + ".txt"
                 file_path = os.path.join(category_path, safe_filename)
-                
+
                 if os.path.exists(file_path): continue
 
                 page = wiki_wiki.page(title)
@@ -213,7 +212,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
                         saved_count += 1
                 except Exception as e:
                     logger.error(f" Could not save article '{title}': {e}")
-            
+
             logger.info(f" Saved {saved_count} new articles for '{category}'.")
 
         logger.info("\n--- Wikipedia Data Download Complete ---")
@@ -313,21 +312,21 @@ def _load_jailbreak_dataset(self):
             logger.info(f"Total loaded jailbreak samples: {len(all_texts)}")
             unique_labels = sorted(list(set(all_labels)))
             label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
-            
+
             logger.info(f"Jailbreak label distribution: {dict(sorted([(label, all_labels.count(label)) for label in set(all_labels)], key=lambda x: x[1], reverse=True))}")
             logger.info(f"Total unique jailbreak labels: {len(unique_labels)}")
-            
+
             label_indices = [label_to_idx[label] for label in all_labels]
             samples = list(zip(all_texts, label_indices))
-            
+
             self.jailbreak_label_mapping = {"label_to_idx": label_to_idx, "idx_to_label": {idx: label for label, idx in label_to_idx.items()}}
             return samples
         except Exception as e:
             import traceback
             logger.error(f"Failed to load jailbreak datasets: {e}")
             traceback.print_exc()
             return []
-    
+
     def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
         os.makedirs(checkpoint_dir, exist_ok=True)
         latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
@@ -371,7 +370,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
             pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
             for step, batch in pbar:
                 if steps_to_skip > 0 and step < steps_to_skip: continue
-                
+
                 optimizer.zero_grad()
                 outputs = self.model(
                     input_ids=batch["input_ids"].to(self.device),
@@ -388,7 +387,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                 optimizer.step(); scheduler.step(); global_step += 1
                 if global_step > 0 and global_step % save_steps == 0:
                     self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
-            
+
             if val_loader: self.evaluate(val_loader)
             self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
             steps_to_skip = 0
@@ -442,13 +441,13 @@ def main():
     logger.info("--- Starting Model Training ---")
 
     task_configs, label_mappings, checkpoint_to_load = {}, {}, None
-    
+
     if args.resume:
         latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
         if os.path.exists(latest_checkpoint_path):
             logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
             checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)
-            
+
             task_configs = checkpoint_to_load.get('task_configs')
             label_mappings = checkpoint_to_load.get('label_mappings')
 
@@ -460,7 +459,7 @@ def main():
                 task_configs = {}
             else:
                 logger.info("Loaded model configuration from checkpoint.")
-        
+
         else:
             logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
             args.resume = False
@@ -475,12 +474,12 @@ def main():
475474 task_configs ["pii" ] = {"num_classes" : len (label_mappings ["pii" ]["label_mapping" ]["label_to_idx" ]), "weight" : 3.0 }
476475 if "jailbreak" in label_mappings :
477476 task_configs ["jailbreak" ] = {"num_classes" : len (label_mappings ["jailbreak" ]["label_mapping" ]["label_to_idx" ]), "weight" : 2.0 }
478-
477+
479478 if not task_configs :
480479 logger .error ("No tasks configured. Exiting." ); return
481480
482481 logger .info (f"Final task configurations: { task_configs } " )
483-
482+
484483 model = MultitaskBertModel (base_model_name , task_configs )
485484
486485 if args .resume and checkpoint_to_load :
@@ -496,9 +495,9 @@ def main():
     active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings
 
     trainer = MultitaskTrainer(model, tokenizer, task_configs, device)
-    
+
     logger.info(f"Total training samples: {len(train_samples)}")
-    
+
     trainer.train(
         train_samples, val_samples, active_label_mappings,
         num_epochs=10, batch_size=16,
@@ -507,7 +506,7 @@ def main():
         save_steps=args.save_steps,
         checkpoint_to_load=checkpoint_to_load
     )
-    
+
     trainer.save_model(output_path)
     with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
         json.dump(active_label_mappings, f, indent=2)