Skip to content

Commit a3b6814

Browse files
committed
update pipeline
1 parent b0c67e2 commit a3b6814

File tree

9 files changed

+325
-291
lines changed

9 files changed

+325
-291
lines changed

.gitignore

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,24 @@ cython_debug/
161161
# and can be added to the global gitignore or merged into this file. For a more nuclear
162162
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
163163
#.idea/
164+
#
165+
--exclude-from=.git/info/exclude
166+
# Lines that start with '#' are comments.
167+
# For a project mostly in C, the following would be a good set of
168+
# exclude patterns (uncomment them if you want to use them):
169+
# *.[oa]
170+
# *~
171+
#
172+
CTTI*
173+
supplementary*
174+
*GNews*
175+
*.zip
176+
labeling/*/*
177+
clinical-trial-outcome-prediction*
178+
*.csv
179+
news_headlines/nct_news_logs/*
180+
manual_labels/*
181+
baselines/data/
182+
baselines/*.pt
183+
*PyTrial*
184+
ablations/

labeling/lfs.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ def get_lfs(lf_each_thresh_path,
260260
LINKAGE_PATH,
261261
LABELS_AND_TICKERS_PATH,
262262
STUDIES_WITH_NEWS_PATH,
263-
no_hint=False):
263+
no_hint=False,
264+
skip_list=None):
264265
lf_thresh_df = pd.read_csv(lf_each_thresh_path, low_memory=False).sort_values(['lf', 'phase','acc'], ascending=False).astype(str)
265266
lf_thresh_df['best_thresh'] = 0
266267
for lf in lf_thresh_df['lf'].unique():
@@ -282,6 +283,13 @@ def get_lfs(lf_each_thresh_path,
282283

283284
known_lfs_list = [hint_lf,hint_lf,hint_lf, status_lf,status_lf, gpt_lf,gpt_lf, linkage_lf,linkage_lf, stock_price_lf, results_reported_lf, new_headlines_lf, pvalues_lf]
284285
df_names = ['hint_train', 'hint_train2', 'hint_train3', 'status','status2', 'gpt','gpt2', 'linkage','linkage2', 'stock_price', 'results_reported', 'new_headlines', 'pvalues']
286+
if skip_list is not None:
287+
inds = []
288+
for i, name in enumerate(df_names):
289+
if name not in skip_list:
290+
inds.append(i)
291+
df_names = [df_names[i] for i in inds]
292+
known_lfs_list = [known_lfs_list[i] for i in inds]
285293
phase_dfs = []
286294
for phase in ['1', '2', '3']:
287295
phase_lfs = known_lfs_list.copy()
@@ -326,8 +334,15 @@ def get_lfs(lf_each_thresh_path,
326334
parser.add_argument('--CTO_GOLD_PATH', type=str, default='/srv/local/data/CTO/outcome_labels/final_cto_labels_2020_2024.csv"')
327335
parser.add_argument('--label_mode', type=str, default='DP')
328336
parser.add_argument('--get_thresholds', type=bool, default=False)
337+
parser.add_argument('--SAVE_PATH', type=str, default="./")
338+
parser.add_argument('--SKIP_LIST', type=str, default=None, help="List of lfs to skip, e.g. ['hint_train', 'status']")
329339
args = parser.parse_args()
330340
print(args)
341+
if args.SKIP_LIST is not None:
342+
args.SKIP_LIST = eval(args.SKIP_LIST)
343+
assert isinstance(args.SKIP_LIST, list), "SKIP_LIST should be a list of strings"
344+
assert all(isinstance(x, str) for x in args.SKIP_LIST), "SKIP_LIST should contain only strings"
345+
print(type(args.SKIP_LIST), args.SKIP_LIST)
331346

332347
cto_gold = pd.read_csv(args.CTO_GOLD_PATH)
333348
cto_gold.rename(columns={'labels': 'label'}, inplace=True)
@@ -409,15 +424,16 @@ def get_lfs(lf_each_thresh_path,
409424
df.to_csv(args.LF_EACH_THRESH_PATH, index=False)
410425

411426
# ==== load best thresholds ====
412-
no_hint = True if args.label_mode != 'DP' else False
427+
no_hint = True if args.label_mode != 'DP' else False # do not compute thresholds if not using DP
413428
df_list = get_lfs(lf_each_thresh_path=args.LF_EACH_THRESH_PATH,
414429
path=args.CTTI_PATH,
415430
HINT_PATH=args.HINT_PATH,
416431
GPT_PATH=args.GPT_PATH,
417432
LINKAGE_PATH=args.LINKAGE_PATH,
418433
LABELS_AND_TICKERS_PATH=args.LABELS_AND_TICKERS_PATH,
419434
STUDIES_WITH_NEWS_PATH=args.STUDIES_WITH_NEWS_PATH,
420-
no_hint=no_hint)
435+
no_hint=no_hint,
436+
skip_list=args.SKIP_LIST)
421437

422438

423439
# ==== fit dp ====
@@ -469,6 +485,10 @@ def get_lfs(lf_each_thresh_path,
469485
label_model.fit(L[:,3:], class_balance=[1-positive_prop, positive_prop], seed=0, lr=lrs[i], n_epochs=300)
470486
label_model_pred_proba = label_model.predict_proba(L[:,3:])[:,1]
471487
label_model_pred = label_model.predict(L[:,3:])
488+
elif args.label_mode == 'MV':
489+
label_model = MajorityLabelVoter(cardinality=2)
490+
label_model_pred_proba = label_model.predict_proba(L)[:,1]
491+
label_model_pred = label_model.predict(L)
472492

473493
# apply status lf
474494
status_lf = lf_status(path=args.CTTI_PATH)
@@ -486,6 +506,7 @@ def get_lfs(lf_each_thresh_path,
486506
print(df2['pred'].value_counts())
487507

488508
df2['pred_proba'] = df2['pred']
509+
df2['pred_proba'] = df2['pred_proba'].astype(float)
489510
mask = df2['pred'] == -1
490511

491512
# apply labelmodel pred where pred == -1
@@ -519,7 +540,10 @@ def get_lfs(lf_each_thresh_path,
519540
cohen_kappa_score(combined['label'], combined['pred']))
520541

521542
# save results
522-
all_combined_full[0].to_csv(f'phase1_{args.label_mode.lower()}.csv', index=False)
523-
all_combined_full[1].to_csv(f'phase2_{args.label_mode.lower()}.csv', index=False)
524-
all_combined_full[2].to_csv(f'phase3_{args.label_mode.lower()}.csv', index=False)
543+
if os.path.exists(args.SAVE_PATH) == False:
544+
os.makedirs(args.SAVE_PATH)
545+
combined.to_csv(os.path.join(args.SAVE_PATH, f'combined_eval_{args.label_mode.lower()}.csv'), index=False)
546+
all_combined_full[0].to_csv(os.path.join(args.SAVE_PATH, f'phase1_{args.label_mode.lower()}.csv'), index=False)
547+
all_combined_full[1].to_csv(os.path.join(args.SAVE_PATH, f'phase2_{args.label_mode.lower()}.csv'), index=False)
548+
all_combined_full[2].to_csv(os.path.join(args.SAVE_PATH, f'phase3_{args.label_mode.lower()}.csv'), index=False)
525549

pipeline.sh

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,70 @@
1-
DATA_PATH=/srv/local/data/CTO/CTTI_new
2-
SAVE_PATH=/srv/local/data/CTO
1+
DATA_PATH=/shared/rsaas/CTO/ctti_11_06_2025
2+
SAVE_PATH=/shared/rsaas/CTO/andy
33

44

55
# # Downloading CTTI new data
66
# echo "Downloading CTTI new data"
77
# python download_ctti.py --save_path $SAVE_PATH
88

99

10-
11-
# # # Getting LLM predictions on Pubmed data
12-
echo "Getting LLM predictions on Pubmed data"
13-
cd llm_prediction_on_pubmed
14-
10+
# ========================= Getting LLM predictions on Pubmed data =========================
11+
# echo "Getting LLM predictions on Pubmed data"
1512
# echo "Extracting and Updating Pubmed data"
16-
# python extract_pubmed_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
13+
# python ./llm_prediction_on_pubmed/extract_pubmed_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
1714
# echo "Search Pubmed and extract abstracts"
18-
# python extract_pubmed_abstracts_through_search.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
15+
# python ./llm_prediction_on_pubmed/extract_pubmed_abstracts_through_search.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
1916
# echo "Retrieving top 2 relevant abstracts"
20-
# python retrieve_top2_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
17+
# python ./llm_prediction_on_pubmed/retrieve_top2_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
2118
# echo "Obtaining LLM predictions"
22-
# python get_llm_predictions.py --save_path $SAVE_PATH --azure #--dev
23-
# python clean_and_extract_final_outcomes.py --save_path $SAVE_PATH
19+
# python ./llm_prediction_on_pubmed/get_llm_predictions.py --save_path $SAVE_PATH --azure #--dev
20+
# python ./llm_prediction_on_pubmed/clean_and_extract_final_outcomes.py --save_path $SAVE_PATH
2421

2522

26-
# # # # Getting Clinical Trial Linkage
23+
# ========================= Getting Clinical Trial Linkage ========================
2724
# echo "Getting Clinical Trial Linkage"
28-
cd ..
29-
cd clinical_trial_linkage
3025

3126
# echo "Downloading FDA orange book and drug code dictionary"
32-
# python download_data.py --save_path $SAVE_PATH # centralize the links in the .sh
27+
# python ./clinical_trial_linkage/download_data.py --save_path $SAVE_PATH # centralize the links in the .sh
3328
# echo "Processing FDA orange book and drug code dictionary"
34-
# python process_drugbank.py --save_path $SAVE_PATH
35-
# python create_drug_mapping.py --save_path $SAVE_PATH
29+
# python ./clinical_trial_linkage/process_drugbank.py --save_path $SAVE_PATH
30+
# python ./clinical_trial_linkage/create_drug_mapping.py --save_path $SAVE_PATH
3631

3732
# echo "Extracting trial info and trial embeddings"
38-
# python extract_trial_info.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
39-
# python get_embedding_for_trial_linkage.py --save_path $SAVE_PATH --num_workers 8 --gpu_ids 0,1,2 #--dev
33+
# python ./clinical_trial_linkage/extract_trial_info.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
34+
# python ./clinical_trial_linkage/get_embedding_for_trial_linkage.py --save_path $SAVE_PATH --num_workers 8 --gpu_ids 0,1,2 #--dev
4035

4136

4237
# echo 'Linking Clinical Trials across phases'
4338
# echo 'Phase 4'
44-
# python create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase4' --num_workers 1 --gpu_ids 4 #--dev
39+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase4' --num_workers 1 --gpu_ids 4 #--dev
4540
# echo 'Phase 3'
46-
# python create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase3' --num_workers 1 --gpu_ids 4 #--dev
41+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase3' --num_workers 1 --gpu_ids 4 #--dev
4742
# echo 'Phase 2/ Phase 3'
48-
# python create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2/phase3' --num_workers 1 --gpu_ids 4 #--dev
43+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2/phase3' --num_workers 1 --gpu_ids 4 #--dev
4944
# echo 'Phase 2'
50-
# python create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2' --num_workers 1 --gpu_ids 4 #--dev
45+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2' --num_workers 1 --gpu_ids 4 #--dev
5146

5247
# echo 'Extract outcomes from Clinical Trial Linkage'
53-
# python extract_outcome_from_trial_linkage.py --save_path $SAVE_PATH
48+
# python ./clinical_trial_linkage/extract_outcome_from_trial_linkage.py --save_path $SAVE_PATH
5449
# echo 'Matching with FDA orange book'
55-
# python match_fda_approvals.py --save_path $SAVE_PATH #--dev
56-
57-
50+
# python ./clinical_trial_linkage/match_fda_approvals.py --save_path $SAVE_PATH #--dev
5851

59-
# News
6052

53+
# ========================= News ========================
54+
# skip for now due to quota limits
55+
# python ./news_headlines/get_news.py --mode=get_news --continue_from_prev_log=True --CTTI_PATH=$DATA_PATH --SENTIMENT_MODEL="cardiffnlp/twitter-roberta-base-sentiment-latest" --SAVE_NEWS_LOG_PATH=$SAVE_PATH/news_headlines/ --SAVE_STUDY_NEWS_PATH=$SAVE_PATH/news.csv
6156

62-
#Stock prices
57+
# # ========================= Stock prices =======================
58+
# echo "Updating stock prices and computing slopes"
59+
# # Ensure tickers.csv exists at ./stock_price/tickers.csv (adjust --TICKERS_PATH as needed)
60+
# python ./stock_price/get_stocks.py --CTTI_PATH $DATA_PATH --TICKERS_PATH ./stock_price/tickers.csv --SAVE_STOCKS_PATH $SAVE_PATH/stock_data.csv.zip --SAVE_STOCKS_SLOPES_PATH $SAVE_PATH/stock_labels.csv
6361

62+
# ========================= Amendments ========================
63+
python ./stock_price/scrape_amendments.py --CTTI_PATH $DATA_PATH --SAVE_PATH $SAVE_PATH/amendment_counts.csv --years 2
6464

65-
# Labeling
66-
# echo "Copy all labeling results to the labeling folder"
67-
cd ..
68-
python arrange_labels.py --save_path $SAVE_PATH
65+
# # ========================= Update Labels =================
66+
# python labeling/lfs.py --get_thresholds=True --LF_EACH_THRESH_PATH=$LF_EACH_THRESH_PATH --CTTI_PATH=$CTTI_PATH --HINT_PATH=$HINT_PATH --LABELS_AND_TICKERS_PATH=$LABELS_AND_TICKERS_PATH --GPT_PATH=$GPT_PATH --LINKAGE_PATH=$LINKAGE_PATH --STUDIES_WITH_NEWS_PATH=$STUDIES_WITH_NEWS_PATH --label_mode=$label_mode --CTO_GOLD_PATH=$CTO_GOLD_PATH --SAVE_PATH=$SAVE_PATH --SKIP_LIST="['new_headlines']"
6967

70-
# limit it to drugs
68+
# # # Labeling
69+
# # cd ..
70+
# # python arrange_labels.py --save_path $SAVE_PATH

pipeline_temp_stocks.sh

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
DATA_PATH=/shared/rsaas/CTO/ctti_11_06_2025
2+
SAVE_PATH=/shared/rsaas/CTO/andy
3+
4+
5+
# # Downloading CTTI new data
6+
# echo "Downloading CTTI new data"
7+
# python download_ctti.py --save_path $SAVE_PATH
8+
9+
10+
# ========================= Getting LLM predictions on Pubmed data =========================
11+
# echo "Getting LLM predictions on Pubmed data"
12+
# echo "Extracting and Updating Pubmed data"
13+
# python ./llm_prediction_on_pubmed/extract_pubmed_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
14+
# echo "Search Pubmed and extract abstracts"
15+
# python ./llm_prediction_on_pubmed/extract_pubmed_abstracts_through_search.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
16+
# echo "Retrieving top 2 relevant abstracts"
17+
# python ./llm_prediction_on_pubmed/retrieve_top2_abstracts.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
18+
# echo "Obtaining LLM predictions"
19+
# python ./llm_prediction_on_pubmed/get_llm_predictions.py --save_path $SAVE_PATH --azure #--dev
20+
# python ./llm_prediction_on_pubmed/clean_and_extract_final_outcomes.py --save_path $SAVE_PATH
21+
22+
23+
# ========================= Getting Clinical Trial Linkage ========================
24+
# echo "Getting Clinical Trial Linkage"
25+
26+
# echo "Downloading FDA orange book and drug code dictionary"
27+
# python ./clinical_trial_linkage/download_data.py --save_path $SAVE_PATH # centralize the links in the .sh
28+
# echo "Processing FDA orange book and drug code dictionary"
29+
# python ./clinical_trial_linkage/process_drugbank.py --save_path $SAVE_PATH
30+
# python ./clinical_trial_linkage/create_drug_mapping.py --save_path $SAVE_PATH
31+
32+
# echo "Extracting trial info and trial embeddings"
33+
# python ./clinical_trial_linkage/extract_trial_info.py --data_path $DATA_PATH --save_path $SAVE_PATH #--dev
34+
# python ./clinical_trial_linkage/get_embedding_for_trial_linkage.py --save_path $SAVE_PATH --num_workers 8 --gpu_ids 0,1,2 #--dev
35+
36+
37+
# echo 'Linking Clinical Trials across phases'
38+
# echo 'Phase 4'
39+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase4' --num_workers 1 --gpu_ids 4 #--dev
40+
# echo 'Phase 3'
41+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase3' --num_workers 1 --gpu_ids 4 #--dev
42+
# echo 'Phase 2/ Phase 3'
43+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2/phase3' --num_workers 1 --gpu_ids 4 #--dev
44+
# echo 'Phase 2'
45+
# python ./clinical_trial_linkage/create_trial_linkage.py --save_path $SAVE_PATH --target_phase 'phase2' --num_workers 1 --gpu_ids 4 #--dev
46+
47+
# echo 'Extract outcomes from Clinical Trial Linkage'
48+
# python ./clinical_trial_linkage/extract_outcome_from_trial_linkage.py --save_path $SAVE_PATH
49+
# echo 'Matching with FDA orange book'
50+
# python ./clinical_trial_linkage/match_fda_approvals.py --save_path $SAVE_PATH #--dev
51+
52+
53+
# ========================= News ========================
54+
# skip for now due to quota limits
55+
# python ./news_headlines/get_news.py --mode=get_news --continue_from_prev_log=True --CTTI_PATH=$DATA_PATH --SENTIMENT_MODEL="cardiffnlp/twitter-roberta-base-sentiment-latest" --SAVE_NEWS_LOG_PATH=$SAVE_PATH/news_headlines/ --SAVE_STUDY_NEWS_PATH=$SAVE_PATH/news.csv
56+
57+
# ========================= Stock prices =======================
58+
echo "Updating stock prices and computing slopes"
59+
# Ensure tickers.csv exists at ./stock_price/tickers.csv (adjust --TICKERS_PATH as needed)
60+
python ./stock_price/get_stocks.py --CTTI_PATH $DATA_PATH --TICKERS_PATH ./stock_price/tickers.csv --SAVE_STOCKS_PATH $SAVE_PATH/stock_data.csv.zip --SAVE_STOCKS_SLOPES_PATH $SAVE_PATH/stock_labels.csv
61+
62+
# # ========================= Amendments ========================
63+
# python ./stock_price/scrape_amendments.py --CTTI_PATH $DATA_PATH --SAVE_PATH $SAVE_PATH/amendment_counts.csv --years 2
64+
65+
# # ========================= Update Labels =================
66+
# python labeling/lfs.py --get_thresholds=True --LF_EACH_THRESH_PATH=$LF_EACH_THRESH_PATH --CTTI_PATH=$CTTI_PATH --HINT_PATH=$HINT_PATH --LABELS_AND_TICKERS_PATH=$LABELS_AND_TICKERS_PATH --GPT_PATH=$GPT_PATH --LINKAGE_PATH=$LINKAGE_PATH --STUDIES_WITH_NEWS_PATH=$STUDIES_WITH_NEWS_PATH --label_mode=$label_mode --CTO_GOLD_PATH=$CTO_GOLD_PATH --SAVE_PATH=$SAVE_PATH --SKIP_LIST="['new_headlines']"
67+
68+
# # # Labeling
69+
# # cd ..
70+
# # python arrange_labels.py --save_path $SAVE_PATH

stock_price/scrape_amendments.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
if __name__ == '__main__':
1717
parser = argparse.ArgumentParser()
1818
parser.add_argument('--CTTI_PATH', type=str, default='../CTTI/')
19+
parser.add_argument('--SAVE_PATH', type=str, default='./amendment_counts.csv')
20+
parser.add_argument("--years", type=int, default=2, help="Number of years back to scrape amendments for")
1921
args = parser.parse_args()
2022

2123
studies = pd.read_csv(os.path.join(args.CTTI_PATH, 'studies.txt'), sep='|')
@@ -28,12 +30,17 @@
2830

2931
studies = studies.dropna(subset=['phase'])
3032

33+
# select only trials with start date within the last `years` years
34+
current_year = time.localtime().tm_year
35+
studies['start_year'] = pd.to_datetime(studies['start_date'], errors='coerce').dt.year
36+
studies = studies[studies['start_year'] >= current_year - args.years]
37+
3138
chrome_options = Options()
3239
chrome_options.add_argument("--headless") #FOR DEBUG COMMENT OUT SO YOU CAN SEE WHAT YOU'RE DOING
3340
driver = webdriver.Firefox(options=chrome_options)
3441

3542
amendment_counts = []
36-
for i, nct in enumerate(tqdm(studies['nct_id'].iloc[59525:])):
43+
for i, nct in enumerate(tqdm(studies['nct_id'])):
3744
try:
3845
driver.get(f'https://clinicaltrials.gov/study/{nct}?tab=history')
3946
# driver.page_source # needs to be called before the next line
@@ -47,7 +54,7 @@
4754

4855
if i % 100 == 0:
4956
out_df = pd.DataFrame(amendment_counts, columns=['nct_id', 'amendment_count'])
50-
out_df.to_csv('./amendment_counts.csv', index=False)
57+
out_df.to_csv(os.path.join(args.SAVE_PATH), index=False)
5158
except Exception as e:
5259
print(f"Error for {nct}: {e}")
5360
# break

0 commit comments

Comments
 (0)