import json
import logging
import os
-from collections import defaultdic
+from collections import defaultdict
from pathlib import Path
import random
import argparse
import requests
import torch
import torch.nn as nn
-from datasets import load_datase
-from sklearn.model_selection import train_test_spli
-from torch.utils.data import DataLoader, Datase
+from datasets import load_dataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
import wikipediaapi


WIKI_DATA_DIR = "wikipedia_data"
ARTICLES_PER_CATEGORY = 10000
PAGEVIEWS_API_USER_AGENT = "MyMultitaskProject/1.0 ([email protected]) requests/2.0"

@@ -90,7 +90,7 @@ def __init__(self, model, tokenizer, task_configs, device="cuda"):
        self.device = device
        self.jailbreak_label_mapping = None
        self.loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs if task_configs}

        self.EXCLUDED_CATEGORIES = {
            "fiction", "culture", "arts", "comics", "media", "entertainment",
            "people", "religion", "sports", "society", "geography", "history"
@@ -104,16 +104,16 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
        if visited_categories is None:
            visited_categories = set()

        if depth >= max_depth or len(visited_categories) > max_articles * 2:
            return []

        articles = []
        visited_categories.add(category_page.title)

        members = list(category_page.categorymembers.values())
        for member in members:
            if len(articles) >= max_articles: break

            if member.ns == wikipediaapi.Namespace.CATEGORY and member.title not in visited_categories:
                member_title_lower = member.title.lower()
                if not any(excluded in member_title_lower for excluded in self.EXCLUDED_CATEGORIES):
@@ -126,7 +126,7 @@ def _get_wiki_articles_recursive(self, category_page, max_articles, visited_cate
                    ))
            elif member.ns == wikipediaapi.Namespace.MAIN:
                articles.append(member.title)

        return list(set(articles))

    def _get_pageviews(self, article_title, session, start_date, end_date):
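The body of _get_pageviews is not included in this diff. For orientation only, a minimal sketch of such a helper, assuming the public Wikimedia Pageviews REST endpoint and %Y%m%d-formatted date strings (both assumptions, not taken from this commit; the function name and default user agent are illustrative):

import requests

def get_pageviews(article_title, session, start_date, end_date,
                  user_agent="MyMultitaskProject/1.0 (example)"):
    """Hypothetical sketch: total daily pageviews for one article over a date range."""
    # Wikimedia expects the title URL-encoded, with spaces as underscores.
    title = requests.utils.quote(article_title.replace(" ", "_"), safe="")
    url = (
        "https://wikimedia.org/api/rest_v1/metrics/pageviews/per-article/"
        f"en.wikipedia/all-access/all-agents/{title}/daily/{start_date}/{end_date}"
    )
    resp = session.get(url, headers={"User-Agent": user_agent}, timeout=10)
    if resp.status_code != 200:
        return 0  # treat missing pageview data as zero rather than aborting the crawl
    return sum(item.get("views", 0) for item in resp.json().get("items", []))

In the calling code below, start_date and end_date would be the '%Y%m%d' strings built in download_wikipedia_articles.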
@@ -153,7 +153,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
        logger.info("--- Starting Wikipedia Data Download and Preparation ---")
        os.makedirs(data_dir, exist_ok=True)
        wiki_wiki = wikipediaapi.Wikipedia('MyMultitaskProject ([email protected])', 'en')

        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=90)
        end_date_str = end_date.strftime('%Y%m%d')
@@ -172,9 +172,9 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di

            num_candidates = articles_per_category * candidate_multiplier
            logger.info(f" Recursively searching for up to {num_candidates} candidate articles (max depth=4)...")

            candidate_titles = self._get_wiki_articles_recursive(cat_page, num_candidates)

            logger.info(f" Found {len(candidate_titles)} unique candidate articles.")

            if not candidate_titles:
@@ -202,7 +202,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
            for title in tqdm(article_titles, desc=f" Downloading '{category}'"):
                safe_filename = "".join([c for c in title if c.isalpha() or c.isdigit() or c == ' ']).rstrip() + ".txt"
                file_path = os.path.join(category_path, safe_filename)

                if os.path.exists(file_path): continue

                page = wiki_wiki.page(title)
@@ -213,7 +213,7 @@ def download_wikipedia_articles(self, categories, articles_per_category, data_di
                    saved_count += 1
                except Exception as e:
                    logger.error(f" Could not save article '{title}': {e}")

            logger.info(f" Saved {saved_count} new articles for '{category}'.")

        logger.info("\n--- Wikipedia Data Download Complete ---")
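The hunk above only covers the download side; how the saved per-category .txt files are read back into training samples is not shown in this diff. A minimal sketch of such a reader, assuming one (text, category) pair per file (the function name and layout below are illustrative):

import os

def load_wikipedia_samples(data_dir="wikipedia_data"):
    """Illustrative reader: one (text, category) pair per saved article file."""
    samples = []
    for category in sorted(os.listdir(data_dir)):
        category_path = os.path.join(data_dir, category)
        if not os.path.isdir(category_path):
            continue
        for filename in os.listdir(category_path):
            if not filename.endswith(".txt"):
                continue
            with open(os.path.join(category_path, filename), encoding="utf-8") as f:
                samples.append((f.read(), category))
    return samples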
@@ -303,7 +303,7 @@ def _load_jailbreak_dataset(self):
            # Using the 'en' (English) configuration
            jb_dataset2 = load_dataset("Babelscape/ALERT", "alert")
            texts2 = jb_dataset2['test']['prompt']
-            # Prefix labels to avoid conflicts with the other datase
+            # Prefix labels to avoid conflicts with the other dataset
            labels2 = ["ALERT:" + str(l) for l in jb_dataset2['test']['category']]
            all_texts.extend(texts2)
            all_labels.extend(labels2)
@@ -313,21 +313,21 @@ def _load_jailbreak_dataset(self):
            logger.info(f"Total loaded jailbreak samples: {len(all_texts)}")
            unique_labels = sorted(list(set(all_labels)))
            label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}

            logger.info(f"Jailbreak label distribution: {dict(sorted([(label, all_labels.count(label)) for label in set(all_labels)], key=lambda x: x[1], reverse=True))}")
            logger.info(f"Total unique jailbreak labels: {len(unique_labels)}")

            label_indices = [label_to_idx[label] for label in all_labels]
            samples = list(zip(all_texts, label_indices))

            self.jailbreak_label_mapping = {"label_to_idx": label_to_idx, "idx_to_label": {idx: label for label, idx in label_to_idx.items()}}
            return samples
        except Exception as e:
            import traceback
            logger.error(f"Failed to load jailbreak datasets: {e}")
            traceback.print_exc()
            return []

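The "ALERT:" prefix above keeps the two jailbreak sources' label spaces disjoint before a single label_to_idx is built over the merged set. A tiny self-contained sketch of that pattern with made-up data (source names and labels here are illustrative, not from the repo):

# Two sources with overlapping raw label names (toy data).
source_a = [("prompt a1", "injection"), ("prompt a2", "benign")]
source_b = [("prompt b1", "injection"), ("prompt b2", "toxicity")]

texts, labels = [], []
for text, label in source_a:
    texts.append(text)
    labels.append("SRC_A:" + label)  # prefix keeps A's "injection" distinct from B's
for text, label in source_b:
    texts.append(text)
    labels.append("SRC_B:" + label)

# One shared index over the merged, prefixed label space.
label_to_idx = {label: idx for idx, label in enumerate(sorted(set(labels)))}
idx_to_label = {idx: label for label, idx in label_to_idx.items()}
samples = [(t, label_to_idx[l]) for t, l in zip(texts, labels)]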
    def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
        os.makedirs(checkpoint_dir, exist_ok=True)
        latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
@@ -343,7 +343,7 @@ def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_
        torch.save(state, latest_checkpoint_path)
        logger.info(f"Checkpoint saved for step {global_step} at {latest_checkpoint_path}")

    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
              checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
        train_dataset = MultitaskDataset(train_samples, self.tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
@@ -372,7 +372,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
            pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
            for step, batch in pbar:
                if steps_to_skip > 0 and step < steps_to_skip: continue

                optimizer.zero_grad()
                outputs = self.model(
                    input_ids=batch["input_ids"].to(self.device),
@@ -383,13 +383,13 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                    task_logits = outputs[task_name][i : i + 1]
                    task_label = batch["label"][i : i + 1].to(self.device)
                    task_weight = self.task_configs[task_name].get("weight", 1.0)
-                    batch_loss += self.loss_fns[task_name](task_logits, task_label) * task_weigh
+                    batch_loss += self.loss_fns[task_name](task_logits, task_label) * task_weight
                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step(); scheduler.step(); global_step += 1
                if global_step > 0 and global_step % save_steps == 0:
                    self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)

            if val_loader: self.evaluate(val_loader)
            self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
            steps_to_skip = 0
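The corrected line above sits in a per-sample loss loop: each sample's logits are taken from its own task head, its cross-entropy is scaled by the task's configured weight, and the results are summed into batch_loss. A minimal standalone sketch of that pattern, with made-up tensors, task names, and weights:

import torch
import torch.nn as nn

task_configs = {"pii": {"weight": 3.0}, "jailbreak": {"weight": 2.0}}
loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs}

# Made-up per-task logits for a batch of 4 samples, plus the task each sample belongs to.
outputs = {
    "pii": torch.randn(4, 5, requires_grad=True),
    "jailbreak": torch.randn(4, 8, requires_grad=True),
}
sample_tasks = ["pii", "jailbreak", "pii", "jailbreak"]
labels = torch.tensor([1, 3, 0, 7])

batch_loss = torch.tensor(0.0)
for i, task_name in enumerate(sample_tasks):
    task_logits = outputs[task_name][i : i + 1]  # keep the batch dim: shape (1, num_classes)
    task_label = labels[i : i + 1]
    task_weight = task_configs[task_name].get("weight", 1.0)
    batch_loss = batch_loss + loss_fns[task_name](task_logits, task_label) * task_weight

batch_loss.backward()  # non-selected rows of each head simply receive zero gradient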
@@ -443,13 +443,13 @@ def main():
    logger.info("--- Starting Model Training ---")

    task_configs, label_mappings, checkpoint_to_load = {}, {}, None

    if args.resume:
        latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
        if os.path.exists(latest_checkpoint_path):
            logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
            checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)

            task_configs = checkpoint_to_load.get('task_configs')
            label_mappings = checkpoint_to_load.get('label_mappings')

@@ -458,10 +458,10 @@ def main():
                logger.warning("Cannot safely resume. Starting a fresh training run.")
                args.resume = False
                checkpoint_to_load = None
                task_configs = {}
            else:
                logger.info("Loaded model configuration from checkpoint.")

        else:
            logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
            args.resume = False
@@ -476,12 +476,12 @@ def main():
        task_configs["pii"] = {"num_classes": len(label_mappings["pii"]["label_mapping"]["label_to_idx"]), "weight": 3.0}
    if "jailbreak" in label_mappings:
        task_configs["jailbreak"] = {"num_classes": len(label_mappings["jailbreak"]["label_mapping"]["label_to_idx"]), "weight": 2.0}

    if not task_configs:
        logger.error("No tasks configured. Exiting."); return

    logger.info(f"Final task configurations: {task_configs}")

    model = MultitaskBertModel(base_model_name, task_configs)

    if args.resume and checkpoint_to_load:
@@ -497,9 +497,9 @@ def main():
    active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings

    trainer = MultitaskTrainer(model, tokenizer, task_configs, device)

    logger.info(f"Total training samples: {len(train_samples)}")

    trainer.train(
        train_samples, val_samples, active_label_mappings,
        num_epochs=10, batch_size=16,
@@ -508,7 +508,7 @@ def main():
        save_steps=args.save_steps,
        checkpoint_to_load=checkpoint_to_load
    )

    trainer.save_model(output_path)
    with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
        json.dump(active_label_mappings, f, indent=2)