5 changes: 5 additions & 0 deletions README.md
@@ -839,6 +839,11 @@ This is a list of TODOs for the repository. If you are interested in contributin
- [ ] Curate higher quality instruction tuning and reasoning datasets for ELMs.
- [ ] Expand upon the current naive distributed training setup to include more efficient and explicit distributed training strategies (i.e., data, tensor, context, pipeline, and expert parallelism, as noted [here](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=5d_parallelism_in_a_nutshell)).
- [x] Add option for data mixing.
- [x] Adjust feature selection for RAG.
- [x] Apply normalization to RAG database.
- [ ] Run cross-dataset ablations.
- [ ] Apply RAG with encoder methods.
- [ ] Add an "Only feature" retrieval option for RAG.
- [x] For preprocessing, stratify by patient so that no patients overlap between train and test.
- [x] Add official splits for benchmarking.
- [x] Upload to huggingface datasets and use huggingface datasets data loading in main.
33 changes: 17 additions & 16 deletions ecg_bench/config.py
@@ -54,22 +54,23 @@ def get_args():
peft_group.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")

### Mode and Environment
mode_group = parser.add_argument_group("Mode and Environment")
mode_group.add_argument("--train", type=str, default=None, choices=["first", "second", "end2end"], help="Training mode")
mode_group.add_argument("--inference", type=str, default=None, choices=["second", "end2end"], help="Inference mode")
mode_group.add_argument("--post_train", action="store_true", default=None, help="Post-training mode")
mode_group.add_argument("--train_encoder", action="store_true", default=None, help="Train encoder too")
mode_group.add_argument("--interpret", action="store_true", default=None, help="Interpret mode")
mode_group.add_argument("--rag", action="store_true", default=None, help="RAG mode")
mode_group.add_argument("--rag_k", type=int, default=1, help="RAG k")
mode_group.add_argument("--rag_prompt_mode", type=str, default="system_prompt", choices=["system_prompt", "user_query"], help="How to integrate RAG results: in system prompt, user query")
mode_group.add_argument("--retrieval_base", type=str, default="combined", choices=["signal", "feature", "combined"], help="Retrieval base for similarity calculation")
mode_group.add_argument("--retrieved_information", type=str, default="combined", choices=["feature", "report", "combined"], help="Type of information to retrieve in output")
mode_group.add_argument("--load_rag_db", type = str, default = None, help = "Load a RAG database")
mode_group.add_argument("--load_rag_db_idx", type = str, default = None, help = "Load a RAG database index")
mode_group.add_argument("--dev", action="store_true", default=None, help="Development mode")
mode_group.add_argument("--log", action="store_true", default=None, help="Enable logging")

mode_group = parser.add_argument_group('Mode and Environment')
mode_group.add_argument('--train', type=str, default=None, choices=['first', 'second', 'end2end'], help='Training mode')
mode_group.add_argument('--inference', type=str, default=None, choices=['second', 'end2end'], help='Inference mode')
mode_group.add_argument('--post_train', action='store_true', default=None, help='Post-training mode')
mode_group.add_argument('--train_encoder', action='store_true', default=None, help='Train encoder too')
mode_group.add_argument('--interpret', action='store_true', default=None, help='Interpret mode')
mode_group.add_argument('--rag', action='store_true', default=None, help='RAG mode')
mode_group.add_argument('--rag_k', type=int, default=1, help='RAG k')
mode_group.add_argument('--rag_prompt_mode', type=str, default='system_prompt', choices=['system_prompt', 'user_query'], help='How to integrate RAG results: into the system prompt or into the user query')
mode_group.add_argument('--retrieval_base', type=str, default='combined', choices=['signal', 'feature', 'combined'], help='Retrieval base for similarity calculation')
mode_group.add_argument('--retrieved_information', type=str, default='combined', choices=['feature', 'report', 'combined'], help='Type of information to retrieve in output')
mode_group.add_argument('--load_rag_db', type=str, default=None, help='Load a RAG database')
mode_group.add_argument('--load_rag_db_idx', type=str, default=None, help='Load a RAG database index')
mode_group.add_argument('--normalized_rag_feature', action='store_true', default=None, help='Enable normalization for RAG features and signals')
mode_group.add_argument('--dev', action='store_true', default=None, help='Development mode')
mode_group.add_argument('--log', action='store_true', default=None, help='Enable logging')

### Distributed Training
dist_group = parser.add_argument_group("Distributed Training")
dist_group.add_argument("--dis", action="store_true", default=None, help="Enable distributed training")
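A minimal sketch (not part of this diff) of how the new RAG flags might be exercised together. It assumes get_args() parses sys.argv with the argparse setup above and that no other arguments are required; all values here are illustrative:

import sys
from ecg_bench.config import get_args

# Hypothetical command line: RAG inference with the new normalization flag.
sys.argv = [
    "main.py",
    "--inference", "second",
    "--rag",                         # enable retrieval-augmented generation
    "--rag_k", "5",                  # retrieve the 5 nearest neighbors
    "--retrieval_base", "combined",  # similarity over signal + features
    "--retrieved_information", "report",
    "--rag_prompt_mode", "system_prompt",
    "--normalized_rag_feature",      # flag added in this diff
]
args = get_args()
assert args.rag and args.rag_k == 5 and args.normalized_rag_feature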
73 changes: 50 additions & 23 deletions ecg_bench/main.py
@@ -12,6 +12,7 @@
import torch.multiprocessing as mp
from datasets import load_dataset
from huggingface_hub import HfFolder, login
from torch.optim import Adam
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -114,6 +115,7 @@ def create_save_path(args, fm):
args.retrieved_information,
args.rag_k,
args.rag_prompt_mode,
args.normalized_rag_feature
])

model_params.append(encoder_in)
@@ -235,9 +237,12 @@ def run_inference(model, test_loader, tokenizer, args, train_utils):
all_seed_results.append(seed_results)

# Construct filename based on args.rag
filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.blackout}_{args.no_signal}.json"

with open(f"{args.checkpoint}/{filename}", "w") as f:
if args.rag:
filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json"
else:
filename = f"seed_{seed}_{args.perturb}_{args.rag}.json"

with open(f"{args.checkpoint}/{filename}", 'w') as f:
json.dump({
"averages": seed_results["metrics"],
"qa_results": seed_results["qa_results"],
@@ -248,7 +253,13 @@
print(f"Statistical results: {statistical_results}")

# Update statistical results filename similarly
stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.blackout}_{args.no_signal}.json"
if args.rag:
stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json"
else:
stat_filename = f"statistical_results_{args.perturb}_{args.rag}.json"

with open(f"{args.checkpoint}/{stat_filename}", 'w') as f:
json.dump(statistical_results, f)

with open(f"{args.checkpoint}/{stat_filename}", "w") as f:
json.dump(statistical_results, f)
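organize_results.py (added below) recovers the run configuration by splitting these filenames on underscores. A quick sketch of that round trip, worth noting because a value like system_prompt itself contains an underscore and splits into two tokens (the values here are illustrative):

# Example RAG filename produced by the code above:
name = "seed_0_None_True_combined_report_5_system_prompt_True.json"
parts = name[:-len(".json")].split("_")
# parts -> ['seed', '0', 'None', 'True', 'combined', 'report', '5',
#           'system', 'prompt', 'True']
# Indices 7 and 8 together encode rag_prompt_mode ('system_prompt').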
@@ -285,25 +296,41 @@ def main(rank, world_size):

print(f"Total number of parameters: {train_utils.count_parameters(model)}")

if args.train:
optimizer_class = train_utils.get_optimizer_class(args.optimizer)
optimizer = ScheduledOptim(
optimizer_class(filter(lambda x: x.requires_grad, model.parameters()),
betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay),
model_object["model_hidden_size"], args)
train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch)
print(f"Length of Train Data: {len(train_data)}")
elif args.inference:
test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch)
print(f"Length of Test Data: {len(test_data)}")

if args.train == "first":
data = train_data.select(range(800000))
elif args.train in ["second", "end2end"]:
data = train_data.select(range(400000))
elif args.inference in ["second", "end2end"]:
data = test_data.select(range(20000))
print("Length of Dataset Considered:", len(data))
# if args.train:
# optimizer_class = train_utils.get_optimizer_class(args.optimizer)
# optimizer = ScheduledOptim(
# optimizer_class(filter(lambda x: x.requires_grad, model.parameters()),
# betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay),
# model_object["model_hidden_size"], args)
# train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch)
# print(f"Length of Train Data: {len(train_data)}")
# elif args.inference:
# test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch)
# print(f"Length of Test Data: {len(test_data)}")

# if args.train == "first":
# data = train_data.select(range(800000))
# elif args.train in ["second", "end2end"]:
# data = train_data.select(range(400000))
# elif args.inference in ["second", "end2end"]:
# data = test_data.select(range(20000))
# print("Length of Dataset Considered:", len(data))

optimizer = ScheduledOptim(
Adam(filter(lambda x: x.requires_grad, model.parameters()),
betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay),
model_object['model_hidden_size'], args)

json_data_file = fm.open_json(f'./data/{args.data}.json')
train_data, test_data = train_utils.split_dataset(json_data_file)

if args.train == 'first':
data = train_data[:800000]
elif args.train in ['second', 'end2end']:
data = train_data[:400000]
elif args.inference in ['second', 'end2end']:
data = test_data[:20000]
print('Length of Dataset:', len(data))

if args.train == "first":
dataset = FirstStageECGDataset(
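ScheduledOptim above wraps Adam and is given the model hidden size, which suggests an inverse-square-root warmup schedule in the style of the original Transformer. A sketch under that assumption (the repository's actual formula may differ):

def lr_at_step(step: int, d_model: int, n_warmup: int, lr_mul: float) -> float:
    # Linear warmup for n_warmup steps, then inverse-sqrt decay.
    scale = (d_model ** -0.5) * min(step ** -0.5, step * n_warmup ** -1.5)
    return lr_mul * scale

# e.g. with n_warmup=500 (the value in the checkpoint names below), the rate
# rises linearly for 500 steps and decays as 1/sqrt(step) afterwards.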
2 changes: 1 addition & 1 deletion ecg_bench/models/encoder_llm/encoder_free_style.py
@@ -45,4 +45,4 @@ def generate_chat(self, input_ids, attention_mask, tokenizer, encoder_out=None,
tokenizer=tokenizer,
inputs_embeds=llm_embeddings,
)
return out
return out
2 changes: 1 addition & 1 deletion ecg_bench/models/encoder_llm/llava_style.py
@@ -44,4 +44,4 @@ def generate_chat(self, input_ids, attention_mask, tokenizer, encoder_out=None,
tokenizer=tokenizer,
inputs_embeds=llm_embeddings,
)
return out
return out
174 changes: 174 additions & 0 deletions ecg_bench/organize_results.py
@@ -0,0 +1,174 @@
import glob
import json
from collections import defaultdict
from ecg_bench.config import get_args

def extract_file_info(file):
filename = file.split('/')[-1]
parts = filename.split('_')

if filename.startswith('seed_'):
# seed_{seed}_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json
seed_num = int(parts[1])
perturb = parts[2]
rag_used = parts[3] == 'True'

if rag_used:
retrieval_base = parts[4]
retrieved_information = parts[5]
rag_k = int(parts[6])
rag_prompt_mode = '_'.join(parts[7:9])  # rejoin the two tokens of 'system_prompt' / 'user_query'
normalized_rag_feature = parts[9].split('.')[0]
else:
retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None
rag_k = None

is_seed = True
else:
# statistical_results_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json
perturb = parts[2]
rag_used = parts[3] == 'True'

if rag_used:
retrieval_base = parts[4]
retrieved_information = parts[5]
rag_k = int(parts[6])
rag_prompt_mode = '_'.join(parts[7:9])  # rejoin the two tokens of 'system_prompt' / 'user_query'
normalized_rag_feature = parts[9].split('.')[0]
else:
retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None
rag_k = None

is_seed = False
seed_num = None

return {
'rag_used': rag_used,
'rag_k': rag_k,
'is_seed': is_seed,
'seed_num': seed_num,
'perturb': perturb,
'retrieval_base': retrieval_base,
'retrieved_information': retrieved_information,
'rag_prompt_mode': rag_prompt_mode,
'normalized_rag_feature': normalized_rag_feature
}
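
# Example (hypothetical path, illustrative values; not part of the diff):
#   extract_file_info('runs/demo/seed_0_None_True_combined_report_5_system_prompt_True.json')
#   -> {'rag_used': True, 'rag_k': 5, 'is_seed': True, 'seed_num': 0,
#       'perturb': 'None', 'retrieval_base': 'combined',
#       'retrieved_information': 'report', 'rag_prompt_mode': 'system_prompt',
#       'normalized_rag_feature': 'True'}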

def process_seed_data(data):
averages = data['averages']
metrics = {}
for metric, value in averages.items():
if metric == 'ROUGE':
metrics[metric] = value['rouge-l']
elif metric == 'BERTSCORE':
metrics[metric] = sum(value['hf-f1']) / len(value['hf-f1'])
else:
metrics[metric] = value
return metrics

def collect_results(json_files):
individual_seeds_no_rag = {}
statistical_no_rag = {}
individual_seeds_rag = defaultdict(dict)
statistical_rag = {}
config_info_no_rag = None
config_info_rag = {}

for file in json_files:
info = extract_file_info(file)
with open(file, 'r') as f:
data = json.load(f)

if info['is_seed']:
metrics = process_seed_data(data)
if info['rag_used']:
rag_key = (
info['rag_k'],
info['retrieval_base'],
info['retrieved_information'],
info['rag_prompt_mode'],
info['normalized_rag_feature']
)
individual_seeds_rag[rag_key][info['seed_num']] = metrics
config_info_rag[rag_key] = info
else:
individual_seeds_no_rag[info['seed_num']] = metrics
config_info_no_rag = info
else:
if info['rag_used']:
rag_key = (
info['rag_k'],
info['retrieval_base'],
info['retrieved_information'],
info['rag_prompt_mode'],
info['normalized_rag_feature']
)
statistical_rag[rag_key] = data
config_info_rag[rag_key] = info

else:
statistical_no_rag = data
config_info_no_rag = info

return (individual_seeds_no_rag, statistical_no_rag,
individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag)

def print_seed_results(title, seed_dict, config_info=None):
if not seed_dict:
return
print(title)
if config_info:
print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}")
for seed in sorted(seed_dict.keys()):
print(f" Seed {seed}:")
for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']:
value = seed_dict[seed][metric] * 100 # Scale to 0-100
print(f" {metric}: {value:.2f}")
print('--------------------------------')

def print_statistical_results(title, stats_dict, config_info=None):
if not stats_dict:
return
print(title)
if config_info:
print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}")
for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']:
value = (stats_dict['ROUGE']['rouge-l'] if metric == 'ROUGE' else
stats_dict['BERTSCORE']['hf-f1'] if metric == 'BERTSCORE' else
stats_dict[metric])
print(f" {metric}:")
for k, v in value.items():
if k != 'raw_values':
formatted_v = f"[{v[0]:.2f}, {v[1]:.2f}]" if isinstance(v, list) else f"{v:.2f}"
print(f" {k}: {formatted_v}")
print('--------------------------------')

def main():
args = get_args()
dataset_name = args.checkpoint.split('/')[2]
model_name = args.checkpoint.split('/')[4].split('_')[0]

print(f"Organizing results for {dataset_name} with {model_name}")

json_files = glob.glob(f'{args.checkpoint}/*.json')
if not json_files:
print("No results files found.")
print('================================================')
return

(individual_seeds_no_rag, statistical_no_rag,
individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) = collect_results(json_files)

print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag, config_info_no_rag)
print_statistical_results("Statistical Results without RAG:", statistical_no_rag, config_info_no_rag)

for rag_key in sorted(individual_seeds_rag.keys()):
config_info = config_info_rag.get(rag_key)
print_seed_results(f"Individual Seed Results with RAG config={rag_key}:", individual_seeds_rag[rag_key], config_info)
print_statistical_results(f"Statistical Results with RAG config={rag_key}:", statistical_rag.get(rag_key, {}), config_info)


print('================================================')

if __name__ == '__main__':
main()
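For orientation, a minimal sketch of driving these functions directly; the run directory is hypothetical, and the import path assumes the file lives at ecg_bench/organize_results.py as this diff indicates:

import glob
from ecg_bench.organize_results import collect_results, print_seed_results

# Same checkpoint layout that org_results.sh (below) constructs.
files = glob.glob("./runs/ecg_instruct_45k_mapped_1250/0/<run_name>/*.json")
(seeds_no_rag, stats_no_rag, seeds_rag, stats_rag,
 cfg_no_rag, cfg_rag) = collect_results(files)
print_seed_results("Individual Seed Results without RAG:", seeds_no_rag, cfg_no_rag)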
2 changes: 1 addition & 1 deletion ecg_bench/runners/inference.py
@@ -27,7 +27,7 @@ def tester_chat(model, dataloader, tokenizer, args, train_utils):
signal_id_index = batch["signal_id_index"].item()
offset = 0
for conv_turn in assistant_ranges:
print("conv_turn", conv_turn)
# print("conv_turn", conv_turn)
start = conv_turn["start"] + 4 + offset
end = conv_turn["end"] + 1 + offset
curr_input_ids = chat_input_ids[:, :start]
19 changes: 14 additions & 5 deletions ecg_bench/scripts/org_results.sh
@@ -1,9 +1,18 @@
#!/bin/bash

datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here
checkpoints=(
"llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_report_5_False"
)
# data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250")
data=("ecg_instruct_45k_mapped_1250")
# data=("ecg-qa_mimic-iv-ecg_mapped_1250")
# data=("ecg-qa_ptbxl_mapped_1250")
# data=("pretrain_mimic_mapped_1250")
# retrieval_base="feature"
# retrieved_information="combined"
# rag_k=1
# rag_prompt_mode="system_prompt"
# normalized_rag_features=True

checkpoints=('llama-3.2-1b-instruct_adam_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_1_None_None_feature_report_1_system_prompt_None_False')
# checkpoints=('qwen2.5-3b_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False')

for d in "${data[@]}"; do
if [ "$d" = "ecg_instruct_pulse_mapped_1250" ]; then
@@ -14,6 +23,6 @@

for ckpt in "${checkpoints[@]}"; do
python organize_results.py \
--checkpoint=./runs/$d/0/$ckpt
--checkpoint=./runs/$d/0/$ckpt/
done
done