diff --git a/README.md b/README.md index 7ab065c..35e38f2 100644 --- a/README.md +++ b/README.md @@ -839,6 +839,11 @@ This is a list of TODOs for the repository. If you are interested in contributin - [ ] Curate higher quality instruction tuning and reasoning datasets for ELMs. - [ ] Expand upon current naive distributed training setting to include more efficient and explicit distributed training strategies (i.e., data, tensor, context, pipeline, and expert parallelism as noted in [here](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=5d_parallelism_in_a_nutshell)). - [x] Add option for data mixing. +- [x] Adjust feature selection for RAG. +- [x] Apply normalization to RAG database. +- [ ] Cross Dataset Ablation. +- [ ] Apply RAG with encoder methods +- [ ] Add "Only feature" retrieval option for RAG - [x] For preprocessing, stratify based on patient, such that no overlapping patients between train and test. - [x] Add official splits for benchmarking. - [x] Upload to huggingface datasets and use huggingface datasets data loading in main. diff --git a/ecg_bench/config.py b/ecg_bench/config.py index bc3fd4a..1978a61 100644 --- a/ecg_bench/config.py +++ b/ecg_bench/config.py @@ -54,22 +54,23 @@ def get_args(): peft_group.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout") ### Mode and Environment - mode_group = parser.add_argument_group("Mode and Environment") - mode_group.add_argument("--train", type=str, default=None, choices=["first", "second", "end2end"], help="Training mode") - mode_group.add_argument("--inference", type=str, default=None, choices=["second", "end2end"], help="Inference mode") - mode_group.add_argument("--post_train", action="store_true", default=None, help="Post-training mode") - mode_group.add_argument("--train_encoder", action="store_true", default=None, help="Train encoder too") - mode_group.add_argument("--interpret", action="store_true", default=None, help="Interpret mode") - mode_group.add_argument("--rag", action="store_true", default=None, help="RAG mode") - mode_group.add_argument("--rag_k", type=int, default=1, help="RAG k") - mode_group.add_argument("--rag_prompt_mode", type=str, default="system_prompt", choices=["system_prompt", "user_query"], help="How to integrate RAG results: in system prompt, user query") - mode_group.add_argument("--retrieval_base", type=str, default="combined", choices=["signal", "feature", "combined"], help="Retrieval base for similarity calculation") - mode_group.add_argument("--retrieved_information", type=str, default="combined", choices=["feature", "report", "combined"], help="Type of information to retrieve in output") - mode_group.add_argument("--load_rag_db", type = str, default = None, help = "Load a RAG database") - mode_group.add_argument("--load_rag_db_idx", type = str, default = None, help = "Load a RAG database index") - mode_group.add_argument("--dev", action="store_true", default=None, help="Development mode") - mode_group.add_argument("--log", action="store_true", default=None, help="Enable logging") - + mode_group = parser.add_argument_group('Mode and Environment') + mode_group.add_argument('--train', type=str, default=None, choices=['first', 'second', 'end2end'], help='Training mode') + mode_group.add_argument('--inference', type=str, default=None, choices=['second', 'end2end'], help='Inference mode') + mode_group.add_argument('--post_train', action='store_true', default=None, help='Post-training mode') + mode_group.add_argument('--train_encoder', action='store_true', default=None, help='Train encoder too') + mode_group.add_argument('--interpret', action='store_true', default=None, help='Interpret mode') + mode_group.add_argument('--rag', action='store_true', default=None, help='RAG mode') + mode_group.add_argument('--rag_k', type=int, default=1, help='RAG k') + mode_group.add_argument('--rag_prompt_mode', type=str, default='system_prompt', choices=['system_prompt', 'user_query'], help='How to integrate RAG results: in system prompt, user query') + mode_group.add_argument('--retrieval_base', type=str, default='combined', choices=['signal', 'feature', 'combined'], help='Retrieval base for similarity calculation') + mode_group.add_argument('--retrieved_information', type=str, default='combined', choices=['feature', 'report', 'combined'], help='Type of information to retrieve in output') + mode_group.add_argument('--load_rag_db', type = str, default = None, help = 'Load a RAG database') + mode_group.add_argument('--load_rag_db_idx', type = str, default = None, help = 'Load a RAG database index') + mode_group.add_argument('--normalized_rag_feature', action='store_true', default=None, help='Enable normalization for RAG features and signals') + mode_group.add_argument('--dev', action='store_true', default=None, help='Development mode') + mode_group.add_argument('--log', action='store_true', default=None, help='Enable logging') + ### Distributed Training dist_group = parser.add_argument_group("Distributed Training") dist_group.add_argument("--dis", action="store_true", default=None, help="Enable distributed training") diff --git a/ecg_bench/main.py b/ecg_bench/main.py index 6d4f85d..8f6e7cb 100644 --- a/ecg_bench/main.py +++ b/ecg_bench/main.py @@ -12,6 +12,7 @@ import torch.multiprocessing as mp from datasets import load_dataset from huggingface_hub import HfFolder, login +from torch.optim import Adam from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -114,6 +115,7 @@ def create_save_path(args, fm): args.retrieved_information, args.rag_k, args.rag_prompt_mode, + args.normalized_rag_feature ]) model_params.append(encoder_in) @@ -235,9 +237,12 @@ def run_inference(model, test_loader, tokenizer, args, train_utils): all_seed_results.append(seed_results) # Construct filename based on args.rag - filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.blackout}_{args.no_signal}.json" - - with open(f"{args.checkpoint}/{filename}", "w") as f: + if args.rag: + filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json" + else: + filename = f"seed_{seed}_{args.perturb}_{args.rag}.json" + + with open(f"{args.checkpoint}/{filename}", 'w') as f: json.dump({ "averages": seed_results["metrics"], "qa_results": seed_results["qa_results"], @@ -248,7 +253,13 @@ def run_inference(model, test_loader, tokenizer, args, train_utils): print(f"Statistical results: {statistical_results}") # Update statistical results filename similarly - stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.blackout}_{args.no_signal}.json" + if args.rag: + stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json" + else: + stat_filename = f"statistical_results_{args.perturb}_{args.rag}.json" + + with open(f"{args.checkpoint}/{stat_filename}", 'w') as f: + json.dump(statistical_results, f) with open(f"{args.checkpoint}/{stat_filename}", "w") as f: json.dump(statistical_results, f) @@ -285,25 +296,41 @@ def main(rank, world_size): print(f"Total number of parameters: {train_utils.count_parameters(model)}") - if args.train: - optimizer_class = train_utils.get_optimizer_class(args.optimizer) - optimizer = ScheduledOptim( - optimizer_class(filter(lambda x: x.requires_grad, model.parameters()), - betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), - model_object["model_hidden_size"], args) - train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch) - print(f"Length of Train Data: {len(train_data)}") - elif args.inference: - test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch) - print(f"Length of Test Data: {len(test_data)}") - - if args.train == "first": - data = train_data.select(range(800000)) - elif args.train in ["second", "end2end"]: - data = train_data.select(range(400000)) - elif args.inference in ["second", "end2end"]: - data = test_data.select(range(20000)) - print("Length of Dataset Considered:", len(data)) + # if args.train: + # optimizer_class = train_utils.get_optimizer_class(args.optimizer) + # optimizer = ScheduledOptim( + # optimizer_class(filter(lambda x: x.requires_grad, model.parameters()), + # betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), + # model_object["model_hidden_size"], args) + # train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch) + # print(f"Length of Train Data: {len(train_data)}") + # elif args.inference: + # test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch) + # print(f"Length of Test Data: {len(test_data)}") + + # if args.train == "first": + # data = train_data.select(range(800000)) + # elif args.train in ["second", "end2end"]: + # data = train_data.select(range(400000)) + # elif args.inference in ["second", "end2end"]: + # data = test_data.select(range(20000)) + # print("Length of Dataset Considered:", len(data)) + + optimizer = ScheduledOptim( + Adam(filter(lambda x: x.requires_grad, model.parameters()), + betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), + model_object['model_hidden_size'], args) + + json_data_file = fm.open_json(f'./data/{args.data}.json') + train_data, test_data = train_utils.split_dataset(json_data_file) + + if args.train == 'first': + data = train_data[:800000] + elif args.train in ['second', 'end2end']: + data = train_data[:400000] + elif args.inference in ['second', 'end2end']: + data = test_data[:20000] + print('Length of Dataset:', len(data)) if args.train == "first": dataset = FirstStageECGDataset( diff --git a/ecg_bench/models/encoder_llm/encoder_free_style.py b/ecg_bench/models/encoder_llm/encoder_free_style.py index db75cd1..1ddd1ca 100644 --- a/ecg_bench/models/encoder_llm/encoder_free_style.py +++ b/ecg_bench/models/encoder_llm/encoder_free_style.py @@ -45,4 +45,4 @@ def generate_chat(self, input_ids, attention_mask, tokenizer, encoder_out=None, tokenizer=tokenizer, inputs_embeds=llm_embeddings, ) - return out + return out \ No newline at end of file diff --git a/ecg_bench/models/encoder_llm/llava_style.py b/ecg_bench/models/encoder_llm/llava_style.py index baef382..e79aec8 100644 --- a/ecg_bench/models/encoder_llm/llava_style.py +++ b/ecg_bench/models/encoder_llm/llava_style.py @@ -44,4 +44,4 @@ def generate_chat(self, input_ids, attention_mask, tokenizer, encoder_out=None, tokenizer=tokenizer, inputs_embeds=llm_embeddings, ) - return out + return out \ No newline at end of file diff --git a/ecg_bench/organize_results.py b/ecg_bench/organize_results.py new file mode 100644 index 0000000..922cfd0 --- /dev/null +++ b/ecg_bench/organize_results.py @@ -0,0 +1,174 @@ +import glob +import json +from collections import defaultdict +from ecg_bench.config import get_args + +def extract_file_info(file): + filename = file.split('/')[-1] + parts = filename.split('_') + + if filename.startswith('seed_'): + # seed_{seed}_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json + seed_num = int(parts[1]) + perturb = parts[2] + rag_used = parts[3] == 'True' + + if rag_used: + retrieval_base = parts[4] + retrieved_information = parts[5] + rag_k = int(parts[6]) + rag_prompt_mode = parts[7]+parts[8] + normalized_rag_feature = parts[9].split('.')[0] + else: + retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None + rag_k = None + + is_seed = True + else: + # statistical_results_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json + perturb = parts[2] + rag_used = parts[3] == 'True' + + if rag_used: + retrieval_base = parts[4] + retrieved_information = parts[5] + rag_k = int(parts[6]) + rag_prompt_mode = parts[7]+parts[8] + normalized_rag_feature = parts[9].split('.')[0] + else: + retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None + rag_k = None + + is_seed = False + seed_num = None + + return { + 'rag_used': rag_used, + 'rag_k': rag_k, + 'is_seed': is_seed, + 'seed_num': seed_num, + 'perturb': perturb, + 'retrieval_base': retrieval_base, + 'retrieved_information': retrieved_information, + 'rag_prompt_mode': rag_prompt_mode, + 'normalized_rag_feature': normalized_rag_feature + } + +def process_seed_data(data): + averages = data['averages'] + metrics = {} + for metric, value in averages.items(): + if metric == 'ROUGE': + metrics[metric] = value['rouge-l'] + elif metric == 'BERTSCORE': + metrics[metric] = sum(value['hf-f1']) / len(value['hf-f1']) + else: + metrics[metric] = value + return metrics + +def collect_results(json_files): + individual_seeds_no_rag = {} + statistical_no_rag = {} + individual_seeds_rag = defaultdict(dict) + statistical_rag = {} + config_info_no_rag = None + config_info_rag = {} + + for file in json_files: + info = extract_file_info(file) + with open(file, 'r') as f: + data = json.load(f) + + if info['is_seed']: + metrics = process_seed_data(data) + if info['rag_used']: + rag_key = ( + info['rag_k'], + info['retrieval_base'], + info['retrieved_information'], + info['rag_prompt_mode'], + info['normalized_rag_feature'] + ) + individual_seeds_rag[rag_key][info['seed_num']] = metrics + config_info_rag[rag_key] = info + else: + individual_seeds_no_rag[info['seed_num']] = metrics + config_info_no_rag = info + else: + if info['rag_used']: + rag_key = ( + info['rag_k'], + info['retrieval_base'], + info['retrieved_information'], + info['rag_prompt_mode'], + info['normalized_rag_feature'] + ) + statistical_rag[rag_key] = data + config_info_rag[rag_key] = info + + else: + statistical_no_rag = data + config_info_no_rag = info + + return (individual_seeds_no_rag, statistical_no_rag, + individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) + +def print_seed_results(title, seed_dict, config_info=None): + if not seed_dict: + return + print(title) + if config_info: + print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}") + for seed in sorted(seed_dict.keys()): + print(f" Seed {seed}:") + for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']: + value = seed_dict[seed][metric] * 100 # Scale to 0-100 + print(f" {metric}: {value:.2f}") + print('--------------------------------') + +def print_statistical_results(title, stats_dict, config_info=None): + if not stats_dict: + return + print(title) + if config_info: + print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}") + for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']: + value = (stats_dict['ROUGE']['rouge-l'] if metric == 'ROUGE' else + stats_dict['BERTSCORE']['hf-f1'] if metric == 'BERTSCORE' else + stats_dict[metric]) + print(f" {metric}:") + for k, v in value.items(): + if k != 'raw_values': + formatted_v = f"[{v[0]:.2f}, {v[1]:.2f}]" if isinstance(v, list) else f"{v:.2f}" + print(f" {k}: {formatted_v}") + print('--------------------------------') + +def main(): + args = get_args() + dataset_name = args.checkpoint.split('/')[2] + model_name = args.checkpoint.split('/')[4].split('_')[0] + + print(f"Organizing results for {dataset_name} with {model_name}") + + json_files = glob.glob(f'{args.checkpoint}/*.json') + if not json_files: + print("No results files found.") + print('================================================') + return + + (individual_seeds_no_rag, statistical_no_rag, + individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) = collect_results(json_files) + + print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag, config_info_no_rag) + print_statistical_results("Statistical Results without RAG:", statistical_no_rag, config_info_no_rag) + + for rag_key in sorted(individual_seeds_rag.keys()): + config_info = config_info_rag.get(rag_key) + print_seed_results(f"Individual Seed Results with RAG config={rag_key}:", individual_seeds_rag[rag_key], config_info) + print_statistical_results(f"Statistical Results with RAG config={rag_key}:", statistical_rag.get(rag_key, {}), config_info) + + + print('================================================') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/ecg_bench/runners/inference.py b/ecg_bench/runners/inference.py index 60db82e..01987db 100644 --- a/ecg_bench/runners/inference.py +++ b/ecg_bench/runners/inference.py @@ -27,7 +27,7 @@ def tester_chat(model, dataloader, tokenizer, args, train_utils): signal_id_index = batch["signal_id_index"].item() offset = 0 for conv_turn in assistant_ranges: - print("conv_turn", conv_turn) + # print("conv_turn", conv_turn) start = conv_turn["start"] + 4 + offset end = conv_turn["end"] + 1 + offset curr_input_ids = chat_input_ids[:, :start] diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index b2066cd..b2811f1 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -1,9 +1,18 @@ #!/bin/bash -datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here -checkpoints=( - "llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_report_5_False" -) +# data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250") +data=("ecg_instruct_45k_mapped_1250") +# data=("ecg-qa_mimic-iv-ecg_mapped_1250") +# data=("ecg-qa_ptbxl_mapped_1250") +# data=("pretrain_mimic_mapped_1250") +# retrieval_base="feature" +# retrieved_information="combined" +# rag_k=1 +# rag_prompt_mode="system_prompt" +# normalized_rag_features=True + +checkpoints='llama-3.2-1b-instruct_adam_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_1_None_None_feature_report_1_system_prompt_None_False' +# checkpoints='qwen2.5-3b_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' for d in "${data[@]}"; do if [ "$d" = "ecg_instruct_pulse_mapped_1250" ]; then @@ -14,6 +23,6 @@ for d in "${data[@]}"; do for ckpt in "${checkpoints[@]}"; do python organize_results.py \ - --checkpoint=./runs/$d/0/$ckpt + --checkpoint=./runs/$d/0/$ckpt/ done done \ No newline at end of file diff --git a/ecg_bench/scripts/train_1st.sh b/ecg_bench/scripts/train_1st.sh index cc12c42..2e91b7c 100644 --- a/ecg_bench/scripts/train_1st.sh +++ b/ecg_bench/scripts/train_1st.sh @@ -1,18 +1,19 @@ #!/bin/bash -models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") +# models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") +models=("merl") +# data=("ecg-qa-mimic-iv-ecg-250-1250") +# data=("ecg_instruct_45k_mapped_1250") ### MULTI GPU for model in "${models[@]}"; do python main.py \ - --data=ecg-qa-mimic-iv-ecg-250-1250 \ + --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ --model=$model \ --device=cuda:4 \ --train=first \ - --batch_size=256 \ + --batch_size=64 \ --seg_len=1250 \ - --lr=8e-5 \ - --weight_decay=1e-4 \ --epochs=50 \ --instance_normalize \ --attn_implementation=flash_attention_2 \ @@ -20,19 +21,20 @@ for model in "${models[@]}"; do done -models=("vit" "clip" "siglip") -for model in "${models[@]}"; do - python main.py \ - --data=test-ecg \ - --model=$model \ - --device=cuda:6 \ - --train=first \ - --batch_size=64 \ - --seg_len=1250 \ - --epochs=50 \ - --instance_normalize \ - --attn_implementation=flash_attention_2 \ - --image \ - --dev -done \ No newline at end of file +# models=("vit" "clip" "siglip") + +# for model in "${models[@]}"; do +# python main.py \ +# --data=test-ecg \ +# --model=$model \ +# --device=cuda:6 \ +# --train=first \ +# --batch_size=64 \ +# --seg_len=1250 \ +# --epochs=50 \ +# --instance_normalize \ +# --attn_implementation=flash_attention_2 \ +# --image \ +# --dev +# done \ No newline at end of file diff --git a/ecg_bench/scripts/train_2nd.sh b/ecg_bench/scripts/train_2nd.sh index cfa31dd..a757bba 100644 --- a/ecg_bench/scripts/train_2nd.sh +++ b/ecg_bench/scripts/train_2nd.sh @@ -1,9 +1,14 @@ #!/usr/bin/env bash # ------------------- CONFIGURABLE LISTS ------------------- -encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") -encoders_checkpoints=("stmem_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "merl_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mlae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mtae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None") -llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct") -datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here +# encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") +encoders=("merl") +encoders_checkpoints=("merl_adam_64_50_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_1_None_None_False") +# llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct") +llms=("llama-3.2-1b-instruct") +# datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here +# datasets=("ecg_instruct_45k_mapped_1250") +datasets=("ecg-qa_mimic-iv-ecg_mapped_1250") +# datasets=("ecg-qa_mimic-iv-ecg_mapped_1250" "ecg-qa_ptbxl_mapped_1250") # ---------------------------------------------------------- for data in "${datasets[@]}"; do @@ -26,7 +31,7 @@ for data in "${datasets[@]}"; do python main.py \ --data="$data" \ --model="${encoder}_${llm}" \ - --device=cuda:7 \ + --device=cuda:3 \ --train=second \ --batch_size=2 \ --seg_len=1250 \ @@ -37,25 +42,8 @@ for data in "${datasets[@]}"; do --attn_implementation=flash_attention_2 \ --system_prompt=./data/system_prompt_e2e.txt \ $([ -n "$checkpoint_path" ] && echo "--encoder_checkpoint=$checkpoint_path") \ - --dev + --log done done done - -models=("vit" "clip" "siglip" ) - -for model in "${models[@]}"; do - python main.py \ - --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ - --model=$model \ - --device=cuda:6 \ - --train=first \ - --batch_size=8 \ - --seg_len=1250 \ - --epochs=2 \ - --instance_normalize \ - --attn_implementation=flash_attention_2 \ - --image \ - --log -done \ No newline at end of file diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index c6c5611..e3a526a 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -81,9 +81,11 @@ def create_position_ids(self, padded_sequence): return position_ids def get_qa(self, altered_text): - if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}": + # if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}": + if self.args.data == f"pretrain_mimic_mapped_{self.args.seg_len}": question, answer = altered_text[0]["value"].replace("\n", "").replace("", ""), altered_text[1]["value"] - elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]: + # elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]: + elif self.args.data in [f"ecg-qa_mimic-iv-ecg_mapped_{self.args.seg_len}", f"ecg-qa_ptbxl_mapped_{self.args.seg_len}"]: question_type, question, answer = altered_text[0], altered_text[1], altered_text[2] answer = " ".join(answer) if isinstance(answer, list) else answer return question, answer @@ -100,29 +102,36 @@ def create_special_tokens(self): self.pad_id = self.llm_tokenizer.convert_tokens_to_ids(self.llm_tokenizer.pad_token) def setup_conversation_template(self, signal = None): - if "llama" in self.args.model: - conv = get_conv_template("llama-3") - elif "qwen" in self.args.model: - conv = get_conv_template("qwen-7b-chat") - elif "gemma" in self.args.model: - conv = get_conv_template("gemma") - - if "gemma" not in self.args.model and ("qwen" in self.args.model or "llama" in self.args.model): - if self.args.rag: - rag_results = self.rag_db.search_similar(query_signal=signal, k=self.args.rag_k, mode="signal") - filtered_rag_results = self.rag_db.format_search(rag_results) + if 'llama' in self.args.model: + conv = get_conv_template('llama-3') + elif 'qwen' in self.args.model: + conv = get_conv_template('qwen-7b-chat') + elif 'gemma' in self.args.model: + conv = get_conv_template('gemma') + feature=None + if self.args.rag and self.args.retrieval_base in ['feature', 'combined']: + original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) + feature=original_feature + if self.args.normalized_rag_feature: + feature = self.rag_db.query_feature_normalization(original_feature) + signal = self.rag_db.query_signal_lead_normalization(signal) + + if 'gemma' not in self.args.model and ('qwen' in self.args.model or 'llama' in self.args.model): + if self.args.rag and self.args.rag_prompt_mode == 'system_prompt': + rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base) + filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information) modified_system_prompt = f"{self.system_prompt}\n{filtered_rag_results}" if self.args.dev: print("filtered_rag_results", filtered_rag_results) print("modified_system_prompt", modified_system_prompt) - conv.set_system_message(modified_system_prompt) else: conv.set_system_message(self.system_prompt) return conv def process_altered_text(self, altered_text): - if self.args.data not in [f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}", + if self.args.data not in [#f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}", + f"ecg_instruct_45k_mapped_{self.args.seg_len}", f"ecg-instruct-pulse-{self.args.target_sf}-{self.args.seg_len}", f"ecg-bench-pulse-{self.args.target_sf}-{self.args.seg_len}"]: question, answer = self.get_qa(altered_text) @@ -137,13 +146,27 @@ def append_messages_to_conv(self, conv, altered_text, signal=None): for message in altered_text: is_human = message["from"].lower() in ["human", "user"] role = conv.roles[0] if is_human else conv.roles[1] - message_value = message["value"].replace("\n", "") - message_value = message_value.replace("\n", "") - message_value = message_value.replace("", "") - message_value = message_value.replace("", "") - message_value = message_value.replace("image", "signal").replace("Image", "Signal") + message_value = message['value'].replace('\n', '') + message_value = message_value.replace('\n', '') + message_value = message_value.replace('', '') + message_value = message_value.replace('', '') + message_value = message_value.replace('image', 'signal').replace('Image', 'Signal') + if self.args.rag and (self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']): + original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) + feature=original_feature + if self.args.normalized_rag_feature: + feature = self.rag_db.query_feature_normalization(original_feature) + signal = self.rag_db.query_signal_lead_normalization(signal) if is_human and count == 0: - message_value = f"\n{message_value}" + if self.args.rag and self.args.rag_prompt_mode == 'user_query': + rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base) + filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information) + if self.args.retrieved_information == 'combined': + message_value = f"\nFeature Information:\n{self.rag_db.convert_features_to_structured(original_feature)}\n\n{filtered_rag_results}Question:\n{message_value}" + elif self.args.retrieved_information == 'report': + message_value = f"\n{filtered_rag_results}Question:\n{message_value}" + else: + message_value = f"Question:\n{message_value}" count += 1 conv.append_message(role, message_value) return conv @@ -380,7 +403,7 @@ def prepare_training_end2end(self, ecg_signal, altered_text): conv = self.setup_conversation_template(signal=ecg_signal) altered_text = self.process_altered_text(altered_text) conv = self.append_messages_to_conv(conv, altered_text, ecg_signal) - + tokens_before, tokens_after = self.get_input_tokens(conv) symbol_signal = self.train_utils.ecg_tokenizer_utils._to_symbol_string(ecg_signal) @@ -405,7 +428,6 @@ def prepare_training_end2end(self, ecg_signal, altered_text): if len(input_ids) < self.args.pad_to_max: padding_length = self.args.pad_to_max - len(input_ids) input_ids = [self.llm_tokenizer.pad_token_id] * padding_length + input_ids - labels = self.create_labels_from_responses(input_ids, altered_text) if self.args.dev: diff --git a/ecg_bench/utils/preprocess_utils.py b/ecg_bench/utils/preprocess_utils.py index 2be42df..53d4405 100644 --- a/ecg_bench/utils/preprocess_utils.py +++ b/ecg_bench/utils/preprocess_utils.py @@ -825,7 +825,62 @@ def extract_features(self, ecg): features.append(np.sqrt(np.mean(np.square(np.diff(lead_signal))))) # Root mean square of successive differences return np.array(features) - + + def extract_rag_features(self, ecg): + """ + Extract a subset of features for RAG applications. + Keeps only: max, min, dominant_frequency, total_power, spectral_centroid, + peak_frequency_power, Heart Rate Features, Wavelet Features, average_absolute_difference, root_mean_square_difference + """ + features = [] + + for lead in range(ecg.shape[0]): + lead_signal = ecg[lead, :] + + # Basic statistical features (only max and min) + features.extend([ + np.max(lead_signal), + np.min(lead_signal) + ]) + + # Frequency domain features + freqs, psd = signal.welch(lead_signal, fs=self.target_sf, nperseg=min(1024, len(lead_signal))) + total_power = np.sum(psd) + features.extend([ + total_power, # Total power + np.max(psd), # Peak frequency power + freqs[np.argmax(psd)], # Dominant frequency + ]) + + # Spectral centroid with NaN handling + if total_power > 0: + spectral_centroid = np.sum(freqs * psd) / total_power + else: + spectral_centroid = 0 + features.append(spectral_centroid) + + # Find peaks with robust thresholding + if np.max(lead_signal) != np.min(lead_signal): # Avoid division by zero + peak_height = 0.3 * (np.max(lead_signal) - np.min(lead_signal)) + np.min(lead_signal) + min_distance = max(int(0.2 * self.target_sf), 1) # Ensure positive distance + peaks, _ = signal.find_peaks(lead_signal, height=peak_height, distance=min_distance) + else: + peaks = [] + + # Heart rate features + heart_rate_features = self._calculate_heart_rate_features(lead_signal, peaks) + features.extend(heart_rate_features) + + # Wavelet features + wavelet_features = self._calculate_wavelet_features(lead_signal) + features.extend(wavelet_features) + + # Non-linear features + features.append(np.mean(np.abs(np.diff(lead_signal)))) # Average absolute difference + features.append(np.sqrt(np.mean(np.square(np.diff(lead_signal))))) # Root mean square of successive differences + + return np.array(features) + def _calculate_heart_rate_features(self, ecg, peaks): if len(peaks) > 1: # Heart rate @@ -895,5 +950,59 @@ def find_st_deviation(self, ecg, peaks): if st_point < len(ecg): return ecg[st_point] - ecg[peaks[-1]] return 0 - - + + def signal_lead_normalization(ecg): + """ + Normalize each lead individually using z-score normalization. + """ + if ecg.shape[0] == 12: + ecg = ecg.T + transpose_back = True + else: + transpose_back = False + + normalized_ecg = np.zeros_like(ecg, dtype=np.float32) + + for lead_idx in range(12): + lead_signal = ecg[:, lead_idx] + lead_mean = np.mean(lead_signal) + lead_std = np.std(lead_signal) + 1e-10 + normalized_ecg[:, lead_idx] = (lead_signal - lead_mean) / lead_std + + if transpose_back: + normalized_ecg = normalized_ecg.T + + return normalized_ecg + + def feature_normalization(self, rag_features): + """ + Normalize RAG features using z-score normalization. + """ + features_per_lead = len(self.ecg_feature_list) + expected_total_features = 12 * features_per_lead + + if rag_features.ndim != 1: + raise ValueError(f"Expected 1D array, got shape {rag_features.shape}") + + if len(rag_features) != expected_total_features: + raise ValueError(f"Expected {expected_total_features} features for 12-lead ECG, got {len(rag_features)}") + + normalized_features = np.zeros_like(rag_features, dtype=np.float32) + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + feature_values.append(rag_features[feature_pos]) + + feature_values = np.array(feature_values) + + feature_mean = np.mean(feature_values) + feature_std = np.std(feature_values) + 1e-10 + + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + normalized_features[feature_pos] = (rag_features[feature_pos] - feature_mean) / feature_std + + return normalized_features + diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 292d963..59c1036 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -13,13 +13,13 @@ def __init__(self, args, fm): self.args = args self.fm = fm self.ecg_feature_list = [ - "mean", - "std", + # "mean", + # "std", "max", "min", - "median", - "25th percentile", - "75th percentile", + # "median", + # "25th percentile", + # "75th percentile", "total power", "peak frequency power", "dominant frequency", @@ -38,60 +38,115 @@ def __init__(self, args, fm): "average absolute difference", "root mean square difference", ] - - print("Loading RAG database...") + + + print('Loading RAG database...') if self.args.create_rag_db: self.preprocessed_dir = f"./data/{self.args.base_data}/preprocessed_{self.args.seg_len}_{self.args.target_sf}" self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) - print("Creating RAG database...") - print("No RAG database found. Creating new one...") + self.feature_dim = 12* len(self.ecg_feature_list) + self.signal_dim = 12*self.args.seg_len + self.feature_weight=np.sqrt(self.signal_dim/self.feature_dim) + print('Creating RAG database...') + print('No RAG database found. Creating new one...') self.metadata = self.create_and_save_db() elif self.args.load_rag_db != None and self.args.load_rag_db_idx != None: print("Loading RAG database from file...") self.metadata = self.fm.open_json(self.args.load_rag_db) # Initialize feature extractor for potential feature extraction self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) - # Load the appropriate index based on retrieval_base - retrieval_base = getattr(self.args, "retrieval_base", "combined") - if retrieval_base in ["signal", "feature", "combined"]: - # If specific index file is provided, use it; otherwise construct path - if self.args.load_rag_db_idx.endswith(".index"): - self.index = faiss.read_index(self.args.load_rag_db_idx) - else: - # Construct index path based on retrieval_base - index_path = f"{self.args.load_rag_db_idx}/{retrieval_base}.index" - self.index = faiss.read_index(index_path) + self.feature_dim = 12*len(self.ecg_feature_list) + self.signal_dim = 12*self.args.seg_len + self.feature_weight=np.sqrt(self.signal_dim/self.feature_dim) + + if self.args.retrieval_base == 'signal': + self.signal_index = faiss.read_index(self.args.load_rag_db_idx) + elif self.args.retrieval_base == 'feature': + self.feature_index = faiss.read_index(self.args.load_rag_db_idx) + elif self.args.retrieval_base == 'combined': + self.combined_index = faiss.read_index(self.args.load_rag_db_idx) else: self.index = faiss.read_index(self.args.load_rag_db_idx) else: - print("Please either create a RAG datbse or load one in.") - - print("Metadata loaded.") - self.reports = [item["report"] for item in self.metadata] - self.file_paths = [item["file_path"] for item in self.metadata] - - print("RAG database loaded.") - # Get dimensions from first vector in index - first_vector = self.index.reconstruct(0) - self.feature_dim = 288 - self.signal_dim = len(first_vector) - self.feature_dim - print("Building sub-indices...") - self._build_sub_indices() - - print("features dim:", self.feature_dim) - print("signals dim:", self.signal_dim) - print("total samples:", len(self.reports)) - print("Index loaded.") + print('Please either create a RAG datbse or load one in.') + + print('Metadata loaded.') + self.reports = [item['report'] for item in self.metadata] + self.file_paths = [item['file_path'] for item in self.metadata] + print(f'RAG {self.args.retrieval_base} database loaded.') + + # print('Building sub-indices...') + # self._build_sub_indices() + + print('features dim:', self.feature_dim) + print('signals dim:', self.signal_dim) + print('total samples:', len(self.reports)) + print(f'Normalization enabled: {self.args.normalized_rag_feature}') + print(f'{self.args.retrieval_base} Index loaded.') + + def query_signal_lead_normalization(self, signal): + """ + Normalize each lead individually using z-score normalization. + """ + if signal.shape[0] == 12: + signal = signal.T + transpose_back = True + else: + transpose_back = False + + normalized_signal = np.zeros_like(signal, dtype=np.float32) + + for lead_idx in range(12): + lead_signal = signal[:, lead_idx] + lead_mean = np.mean(lead_signal) + lead_std = np.std(lead_signal) + 1e-10 + normalized_signal[:, lead_idx] = (lead_signal - lead_mean) / lead_std + + if transpose_back: + normalized_signal = normalized_signal.T + + return normalized_signal + + def query_feature_normalization(self, rag_features): + """ + Normalize RAG features using z-score normalization. + """ + expected_total_features = self.feature_dim + + if rag_features.ndim != 1: + raise ValueError(f"Expected 1D array, got shape {rag_features.shape}") + + if len(rag_features) != expected_total_features: + raise ValueError(f"Expected {expected_total_features} features for 12-lead ECG, got {len(rag_features)}") + + normalized_features = np.zeros_like(rag_features, dtype=np.float32) + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx + feature_values.append(rag_features[feature_pos]) + + feature_values = np.array(feature_values) + + feature_mean = np.mean(feature_values) + feature_std = np.std(feature_values) + 1e-10 + + for lead_idx in range(12): + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx + normalized_features[feature_pos] = (rag_features[feature_pos] - feature_mean) / feature_std + + return normalized_features def _build_sub_indices(self): - ntotal = self.index.ntotal - nlist = min(100, max(1, ntotal // 30)) - + ntotal = self.combined_index.ntotal + nlist=min(100, max(1, ntotal // 30)) + feature_vectors = np.zeros((ntotal, self.feature_dim), dtype=np.float32) signal_vectors = np.zeros((ntotal, self.signal_dim), dtype=np.float32) for i in range(ntotal): - full_vector = self.index.reconstruct(i) + full_vector = self.combined_index.reconstruct(i) feature_vectors[i] = full_vector[:self.feature_dim] signal_vectors[i] = full_vector[self.feature_dim:] @@ -112,54 +167,57 @@ def _build_sub_indices(self): def create_and_save_db(self): print("Initializing RAG database creation...") metadata = [] - vectors_for_index = [] + combined_vectors = [] feature_vectors = [] signal_vectors = [] - - - npy_files = list(Path(self.preprocessed_dir).glob("*.npy")) + + npy_files = list(Path(self.preprocessed_dir).glob('*.npy')) if self.args.dev: - npy_files = npy_files[:1000] - print(f"Development mode: Processing {len(npy_files)} files") + npy_files = npy_files[:300] + print(f'Development mode: Processing {len(npy_files)} files') if self.args.toy: npy_files = npy_files[:400000] - print(f"Toy mode: Processing {len(npy_files)} files") - - print(f"Found {len(npy_files)} files to process") - print("Starting feature extraction from ECG signals...") - + print(f'Toy mode: Processing {len(npy_files)} files') + + print(f'Found {len(npy_files)} files to process') + print(f'Normalization enabled: {self.args.normalized_rag_feature}') + print('Starting feature extraction from ECG signals...') + for file_path in tqdm(npy_files, desc="Extracting features"): try: data = self.fm.open_npy(file_path) - ecg = data["ecg"] - report = data["report"] - features = self.feature_extractor.extract_features(ecg) + ecg = data['ecg'] + report = data['report'] + features = self.feature_extractor.extract_rag_features(ecg).flatten() + metadata.append({ + 'report': report, + 'file_path': str(file_path), + }) - # Store vectors for different indices - feature_vector = features.flatten() - signal_vector = ecg.flatten() - combined_vector = np.hstack([feature_vector, signal_vector]) + if not self.args.normalized_rag_feature: + signal_vector = ecg.flatten() + feature_vector = features.flatten() + + else: + signal_vector = self.query_signal_lead_normalization(ecg).flatten() + feature_vector = self.query_feature_normalization(features).flatten() - feature_vectors.append(feature_vector) + combined_vector = np.hstack([feature_vector*self.feature_weight, signal_vector]) signal_vectors.append(signal_vector) - vectors_for_index.append(combined_vector) + feature_vectors.append(feature_vector) + combined_vectors.append(combined_vector) + - # Store only metadata in JSON - metadata.append({ - "report": report, - "file_path": str(file_path), - }) except Exception as e: print(f"Error processing {file_path}: {e!s}") continue - print(f"Successfully processed {len(metadata)} files") - print("Converting vectors to arrays...") - + print(f'Successfully processed {len(metadata)} files') + # Convert to arrays feature_array = np.stack(feature_vectors) signal_array = np.stack(signal_vectors) - combined_array = np.stack(vectors_for_index) + combined_array = np.stack(combined_vectors) # Calculate optimal number of clusters based on dataset size ntotal = len(combined_array) @@ -170,57 +228,58 @@ def create_and_save_db(self): # Create and save feature index print("Creating feature index...") quantizer_feature = faiss.IndexFlatL2(feature_array.shape[1]) - feature_index = faiss.IndexIVFFlat(quantizer_feature, feature_array.shape[1], nlist) - print("Training feature index...") - feature_index.train(feature_array) - print("Adding vectors to feature index...") - feature_index.add(feature_array) - feature_index.make_direct_map() - feature_path = f"./data/{self.args.base_data}/feature.index" - print(f"Saving feature index to {feature_path}...") - faiss.write_index(feature_index, feature_path) - print("Feature index saved successfully!") - + self.feature_index = faiss.IndexIVFFlat(quantizer_feature, feature_array.shape[1], nlist) + print('Training feature index...') + self.feature_index.train(feature_array) + print('Adding vectors to feature index...') + self.feature_index.add(feature_array) + self.feature_index.make_direct_map() + feature_path = f"./data/{self.args.base_data}/feature_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" + print(f'Saving feature index to {feature_path}...') + faiss.write_index(self.feature_index, feature_path) + print('Feature index saved successfully!') + # Create and save signal index print("Creating signal index...") quantizer_signal = faiss.IndexFlatL2(signal_array.shape[1]) - signal_index = faiss.IndexIVFFlat(quantizer_signal, signal_array.shape[1], nlist) - print("Training signal index...") - signal_index.train(signal_array) - print("Adding vectors to signal index...") - signal_index.add(signal_array) - signal_index.make_direct_map() - signal_path = f"./data/{self.args.base_data}/signal.index" - print(f"Saving signal index to {signal_path}...") - faiss.write_index(signal_index, signal_path) - print("Signal index saved successfully!") - + self.signal_index = faiss.IndexIVFFlat(quantizer_signal, signal_array.shape[1], nlist) + print('Training signal index...') + self.signal_index.train(signal_array) + print('Adding vectors to signal index...') + self.signal_index.add(signal_array) + self.signal_index.make_direct_map() + signal_path = f"./data/{self.args.base_data}/signal_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" + print(f'Saving signal index to {signal_path}...') + faiss.write_index(self.signal_index, signal_path) + print('Signal index saved successfully!') + # Create and save combined index print("Creating combined index...") quantizer_combined = faiss.IndexFlatL2(combined_array.shape[1]) - self.index = faiss.IndexIVFFlat(quantizer_combined, combined_array.shape[1], nlist) - print("Training combined index...") - self.index.train(combined_array) - print("Adding vectors to combined index...") - self.index.add(combined_array) - self.index.make_direct_map() - combined_path = f"./data/{self.args.base_data}/combined.index" - print(f"Saving combined index to {combined_path}...") - faiss.write_index(self.index, combined_path) - print("Combined index saved successfully!") - + self.combined_index = faiss.IndexIVFFlat(quantizer_combined, combined_array.shape[1], nlist) + print('Training combined index...') + self.combined_index.train(combined_array) + print('Adding vectors to combined index...') + self.combined_index.add(combined_array) + self.combined_index.make_direct_map() + combined_path = f"./data/{self.args.base_data}/combined_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" + print(f'Saving combined index to {combined_path}...') + faiss.write_index(self.combined_index, combined_path) + print('Combined index saved successfully!') + # Save metadata JSON metadata_path = f"./data/{self.args.base_data}/rag_metadata.json" print(f"Saving metadata to {metadata_path}...") self.fm.save_json(metadata, metadata_path) - print("Metadata saved successfully!") - - print("RAG database creation completed successfully!") - print(f"Total samples: {len(metadata)}") - print(f"Feature dimension: {feature_array.shape[1]}") - print(f"Signal dimension: {signal_array.shape[1]}") - print(f"Combined dimension: {combined_array.shape[1]}") - + print('Metadata saved successfully!') + + + print('RAG database creation completed successfully!') + print(f'Total samples: {len(metadata)}') + print(f'Feature dimension: {feature_array.shape[1]}') + print(f'Signal dimension: {signal_array.shape[1]}') + print(f'Combined dimension: {combined_array.shape[1]}') + return metadata def search_similar(self, query_features=None, query_signal=None, k=5, mode="signal",nprobe=10): @@ -240,85 +299,113 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode="sign self.feature_index.nprobe = nprobe query_features = query_features.reshape(1, self.feature_dim) distances, indices = self.feature_index.search(query_features, k) - original_indices = [int(self.index_mapping[idx]) for idx in indices[0]] - - elif mode == "signal": + original_indices = indices[0] + elif mode == 'signal': self.signal_index.nprobe = nprobe query_signal = query_signal.reshape(1, -1) distances, indices = self.signal_index.search(query_signal, k) - original_indices = [int(self.signal_mapping[idx]) for idx in indices[0]] - + original_indices = indices[0] + else: # combined mode - self.index.nprobe = nprobe - query_combined = np.hstack([query_features, query_signal]) - query_combined = query_combined.reshape(1, -1) - distances, indices = self.index.search(query_combined, k) - original_indices = [int(idx) for idx in indices[0]] - - + self.combined_index.nprobe = nprobe + query_features = query_features.reshape(1, self.feature_dim) + query_signal = query_signal.reshape(1, -1) + query_combined = np.hstack([query_features*self.feature_weight, query_signal]).reshape(1, -1) + distances, indices = self.combined_index.search(query_combined, k) + original_indices = indices[0] # Prepare results using reconstructed vectors from index results = {} for i, (dist, idx) in enumerate(zip(distances[0], original_indices)): - full_vector = self.index.reconstruct(int(idx)) - features = full_vector[:self.feature_dim] - signal = full_vector[self.feature_dim:] + file_path = self.file_paths[idx] + signal=self.fm.open_npy(file_path)['ecg'] + features=self.feature_extractor.extract_rag_features(signal) result_dict = { - "signal": signal, - "feature": features, - "report": self.reports[idx], - "distance": float(dist), - "file_path": self.file_paths[idx], + 'signal': signal, + 'feature': features, + 'report': self.reports[idx], + 'distance': float(dist), + 'file_path': file_path } results[i] = result_dict return results - - def format_search(self, results, retrieved_information="combined"): - results = self.filter_results(results) + + def format_search(self, results, retrieved_information='combined'): + if retrieved_information not in ['feature', 'report', 'combined']: + raise ValueError("retrieved_information must be 'feature', 'report', or 'combined'") + # results = self.filter_results(results) output = f"The following is the top {len(results)} retrieved ECGs and their corresponding " # Adjust the description based on retrieved_information - if retrieved_information == "feature": - output += "features. Utilize this information to further enhance your response.\n\n" - elif retrieved_information == "report": - output += "diagnosis. Utilize this information to further enhance your response.\n\n" + if retrieved_information == 'feature': + output += "features. Utilize this information to further enhance your response.\n\nThe lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" + elif retrieved_information == 'report': + output += "diagnosis. Utilize this information to further enhance your response. \n\n" else: # combined - output += "features and diagnosis. Utilize this information to further enhance your response.\n\n" - + output += "features and diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" + for idx, res in results.items(): # Filter out entries where all feature values are zero if np.all(np.array(res["feature"]) == 0): continue - output += f"Retrieved ECG {idx+1}\n" + if self.args.dev: + output+=f"Distance: {res['distance']}\n" + # Include feature information based on retrieved_information + if retrieved_information in ['feature', 'combined']: + output += "Feature Information:\n" + + # Organize features by feature type across all leads + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx + feature_values.append(round(float(res['feature'][feature_pos]), 6)) + output += f"{feature_name}: {feature_values}\n" + output += "\n" # Include diagnosis information based on retrieved_information if retrieved_information in ["report", "combined"]: output += "Diagnosis Information:\n" + # output +="--------------------------" output += f"{res['report']}\n\n" - - # Include feature information based on retrieved_information - if retrieved_information in ["feature", "combined"]: - output += "Feature Information:\n" - # Zip through feature names and feature values to format each line. - for feature_name, feature_value in zip(self.ecg_feature_list, res["feature"]): - output += f"{feature_name}: {round(float(feature_value), 6)!s}\n" - output += "\n" return output - + + def convert_features_to_structured(self, feature_array): + """ + Convert a flat feature array into a formatted string organized by feature type. + + Args: + feature_array: numpy array of shape (228,) containing RAG features for 12 leads + + Returns: + formatted_string: formatted string with feature names and arrays of 12 values + """ + if len(feature_array) != self.feature_dim: + raise ValueError(f"Expected {self.feature_dim} features, got {len(feature_array)}") + + formatted_output = "" + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx + feature_values.append(round(float(feature_array[feature_pos]), 6)) + formatted_output += f"{feature_name}: {feature_values}\n" + + return formatted_output + def filter_results(self, results): filtered_results = {} count = 0 for idx, res in results.items(): - # Check if more than x% of values are exactly zero or if the sum is too small - feature_array = np.array(res["feature"]) + feature_array = np.array(res['feature']) zero_percentage = np.sum(np.abs(feature_array) < 1e-3) / len(feature_array) - total_magnitude = np.sum(np.abs(feature_array)) + # total_magnitude = np.sum(np.abs(feature_array)) - # Filter out entries that are mostly zeros or have very low total magnitude - if zero_percentage > 0.5 or total_magnitude < 0.5: + if zero_percentage > 0.6: continue filtered_results[count] = res @@ -327,32 +414,40 @@ def filter_results(self, results): def test_search(self): self.preprocessed_dir = f"./data/{self.args.base_data}/preprocessed_{self.args.seg_len}_{self.args.target_sf}" - npy_files = list(Path(self.preprocessed_dir).glob("*.npy")) - random_idx = np.random.randint(0, len(npy_files)) - query_signal = self.fm.open_npy(npy_files[random_idx])["ecg"] - print("query_signal", query_signal.shape) + rng = np.random.RandomState(42) + npy_files = list(Path(self.preprocessed_dir).glob('*.npy')) + random_idx = rng.randint(0, len(npy_files)) + query_signal = self.fm.open_npy(npy_files[random_idx])['ecg'] + query_report = self.fm.open_npy(npy_files[random_idx])['report'] + print('query_report: /n', query_report) + print('query_signal', query_signal.shape) + # Flatten the signal to match the expected dimensions - query_signal = query_signal.flatten() + query_signal_flat = query_signal.flatten() + start_time = time.time() # Use retrieval_base parameter to determine search mode - retrieval_base = getattr(self.args, "retrieval_base", "signal") - if retrieval_base == "feature": + retrieval_base = getattr(self.args, 'retrieval_base', 'combined') + if retrieval_base == 'feature': # Extract features for feature-based search - # Need to reshape back to 2D for feature extraction - query_signal_2d = query_signal.reshape(12, -1) - features = self.feature_extractor.extract_features(query_signal_2d) - results = self.search_similar(query_features=features, k=10, mode="feature") - elif retrieval_base == "combined": + features = self.feature_extractor.extract_rag_features(query_signal) + if self.args.normalized_rag_feature: + features = self.query_feature_normalization(features) + results = self.search_similar(query_features=features, k=3, mode='feature') + elif retrieval_base == 'combined': # Extract features for combined search - # Need to reshape back to 2D for feature extraction - query_signal_2d = query_signal.reshape(12, -1) - features = self.feature_extractor.extract_features(query_signal_2d) - results = self.search_similar(query_features=features, query_signal=query_signal, k=10, mode="combined") + features = self.feature_extractor.extract_rag_features(query_signal) + if self.args.normalized_rag_feature: + features = self.query_feature_normalization(features) + query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() + results = self.search_similar(query_features=features, query_signal=query_signal_flat, k=3, mode='combined') else: # signal mode (default) - results = self.search_similar(query_signal=query_signal, k=10, mode="signal") - - formatted_results = self.format_search(results, retrieved_information=getattr(self.args, "retrieved_information", "combined")) + if self.args.normalized_rag_feature: + query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() + results = self.search_similar(query_signal=query_signal_flat, k=3, mode='signal') + + formatted_results = self.format_search(results, retrieved_information=getattr(self.args, 'retrieved_information', 'combined')) print(formatted_results) end_time = time.time() print(f"Search time: {end_time - start_time:.2f} seconds") diff --git a/ecg_bench/utils/training_utils.py b/ecg_bench/utils/training_utils.py index 65121ef..42ec931 100644 --- a/ecg_bench/utils/training_utils.py +++ b/ecg_bench/utils/training_utils.py @@ -34,6 +34,17 @@ def __init__(self, args, fm, viz, device, ecg_tokenizer_utils=None): self.args, self.fm, self.viz, self.device = args, fm, viz, device self.ecg_tokenizer_utils = ecg_tokenizer_utils self.cache_dir = "../.huggingface" + + def split_dataset(self, data, train_ratio=0.7): + data = np.array(data) + n_samples = len(data) + indices = np.random.permutation(n_samples) + n_train = int(n_samples * train_ratio) + train_indices = indices[:n_train] + test_indices = indices[n_train:] + train_data = [data[i] for i in train_indices] + test_data = [data[i] for i in test_indices] + return train_data, test_data def save_config(self): args_dict = {k: v for k, v in vars(self.args).items() if not k.startswith("_")} diff --git a/transformers b/transformers index 51f94ea..241c04d 160000 --- a/transformers +++ b/transformers @@ -1 +1 @@ -Subproject commit 51f94ea06d19a6308c61bbb4dc97c40aabd12bad +Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc