@@ -260,7 +260,8 @@ def get_lfs(lf_each_thresh_path,
260260 LINKAGE_PATH ,
261261 LABELS_AND_TICKERS_PATH ,
262262 STUDIES_WITH_NEWS_PATH ,
263- no_hint = False ):
263+ no_hint = False ,
264+ skip_list = None ):
264265 lf_thresh_df = pd .read_csv (lf_each_thresh_path , low_memory = False ).sort_values (['lf' , 'phase' ,'acc' ], ascending = False ).astype (str )
265266 lf_thresh_df ['best_thresh' ] = 0
266267 for lf in lf_thresh_df ['lf' ].unique ():
@@ -282,6 +283,13 @@ def get_lfs(lf_each_thresh_path,
282283
283284 known_lfs_list = [hint_lf ,hint_lf ,hint_lf , status_lf ,status_lf , gpt_lf ,gpt_lf , linkage_lf ,linkage_lf , stock_price_lf , results_reported_lf , new_headlines_lf , pvalues_lf ]
284285 df_names = ['hint_train' , 'hint_train2' , 'hint_train3' , 'status' ,'status2' , 'gpt' ,'gpt2' , 'linkage' ,'linkage2' , 'stock_price' , 'results_reported' , 'new_headlines' , 'pvalues' ]
286+ if skip_list is not None :
287+ inds = []
288+ for i , name in enumerate (df_names ):
289+ if name not in skip_list :
290+ inds .append (i )
291+ df_names = [df_names [i ] for i in inds ]
292+ known_lfs_list = [known_lfs_list [i ] for i in inds ]
285293 phase_dfs = []
286294 for phase in ['1' , '2' , '3' ]:
287295 phase_lfs = known_lfs_list .copy ()
@@ -326,8 +334,15 @@ def get_lfs(lf_each_thresh_path,
326334 parser .add_argument ('--CTO_GOLD_PATH' , type = str , default = '/srv/local/data/CTO/outcome_labels/final_cto_labels_2020_2024.csv"' )
327335 parser .add_argument ('--label_mode' , type = str , default = 'DP' )
328336 parser .add_argument ('--get_thresholds' , type = bool , default = False )
337+ parser .add_argument ('--SAVE_PATH' , type = str , default = "./" )
338+ parser .add_argument ('--SKIP_LIST' , type = str , default = None , help = "List of lfs to skip, e.g. ['hint_train', 'status']" )
329339 args = parser .parse_args ()
330340 print (args )
341+ if args .SKIP_LIST is not None :
342+ args .SKIP_LIST = eval (args .SKIP_LIST )
343+ assert isinstance (args .SKIP_LIST , list ), "SKIP_LIST should be a list of strings"
344+ assert all (isinstance (x , str ) for x in args .SKIP_LIST ), "SKIP_LIST should contain only strings"
345+ print (type (args .SKIP_LIST ), args .SKIP_LIST )
331346
332347 cto_gold = pd .read_csv (args .CTO_GOLD_PATH )
333348 cto_gold .rename (columns = {'labels' : 'label' }, inplace = True )
@@ -409,15 +424,16 @@ def get_lfs(lf_each_thresh_path,
409424 df .to_csv (args .LF_EACH_THRESH_PATH , index = False )
410425
411426 # ==== load best thresholds ====
412- no_hint = True if args .label_mode != 'DP' else False
427+ no_hint = True if args .label_mode != 'DP' else False # do not compute thresholds if not using DP
413428 df_list = get_lfs (lf_each_thresh_path = args .LF_EACH_THRESH_PATH ,
414429 path = args .CTTI_PATH ,
415430 HINT_PATH = args .HINT_PATH ,
416431 GPT_PATH = args .GPT_PATH ,
417432 LINKAGE_PATH = args .LINKAGE_PATH ,
418433 LABELS_AND_TICKERS_PATH = args .LABELS_AND_TICKERS_PATH ,
419434 STUDIES_WITH_NEWS_PATH = args .STUDIES_WITH_NEWS_PATH ,
420- no_hint = no_hint )
435+ no_hint = no_hint ,
436+ skip_list = args .SKIP_LIST )
421437
422438
423439 # ==== fit dp ====
@@ -469,6 +485,10 @@ def get_lfs(lf_each_thresh_path,
469485 label_model .fit (L [:,3 :], class_balance = [1 - positive_prop , positive_prop ], seed = 0 , lr = lrs [i ], n_epochs = 300 )
470486 label_model_pred_proba = label_model .predict_proba (L [:,3 :])[:,1 ]
471487 label_model_pred = label_model .predict (L [:,3 :])
488+ elif args .label_mode == 'MV' :
489+ label_model = MajorityLabelVoter (cardinality = 2 )
490+ label_model_pred_proba = label_model .predict_proba (L )[:,1 ]
491+ label_model_pred = label_model .predict (L )
472492
473493 # apply status lf
474494 status_lf = lf_status (path = args .CTTI_PATH )
@@ -486,6 +506,7 @@ def get_lfs(lf_each_thresh_path,
486506 print (df2 ['pred' ].value_counts ())
487507
488508 df2 ['pred_proba' ] = df2 ['pred' ]
509+ df2 ['pred_proba' ] = df2 ['pred_proba' ].astype (float )
489510 mask = df2 ['pred' ] == - 1
490511
491512 # apply labelmodel pred where pred == -1
@@ -519,7 +540,10 @@ def get_lfs(lf_each_thresh_path,
519540 cohen_kappa_score (combined ['label' ], combined ['pred' ]))
520541
521542 # save results
522- all_combined_full [0 ].to_csv (f'phase1_{ args .label_mode .lower ()} .csv' , index = False )
523- all_combined_full [1 ].to_csv (f'phase2_{ args .label_mode .lower ()} .csv' , index = False )
524- all_combined_full [2 ].to_csv (f'phase3_{ args .label_mode .lower ()} .csv' , index = False )
543+ if os .path .exists (args .SAVE_PATH ) == False :
544+ os .makedirs (args .SAVE_PATH )
545+ combined .to_csv (os .path .join (args .SAVE_PATH , f'combined_eval_{ args .label_mode .lower ()} .csv' ), index = False )
546+ all_combined_full [0 ].to_csv (os .path .join (args .SAVE_PATH , f'phase1_{ args .label_mode .lower ()} .csv' ), index = False )
547+ all_combined_full [1 ].to_csv (os .path .join (args .SAVE_PATH , f'phase2_{ args .label_mode .lower ()} .csv' ), index = False )
548+ all_combined_full [2 ].to_csv (os .path .join (args .SAVE_PATH , f'phase3_{ args .label_mode .lower ()} .csv' ), index = False )
525549
0 commit comments