This repository contains the official implementation of the paper:
Prompt4Trust: A Reinforcement Learning Prompt Augmentation Framework for Clinically-Aligned Confidence Calibration in Multimodal Large Language Models
Anita Kriz*, Elizabeth Laura Janes*, Xing Shen*, Tal Arbel
*Equal contribution
IEEE/CVF International Conference on Computer Vision 2025 Workshop CVAMD
Paper (arXiv preprint)
Multimodal large language models (MLLMs) show great potential for healthcare applications, but their clinical deployment is challenged by prompt sensitivity and overconfident incorrect responses. To improve trustworthiness in safety-critical settings, we introduce Prompt4Trust, the first reinforcement learning framework for prompt augmentation focused on confidence calibration in MLLMs. A lightweight LLM is trained to generate context-aware auxiliary prompts that guide a downstream MLLM to produce predictions with confidence scores that better reflect true accuracy. By prioritizing clinically meaningful calibration, Prompt4Trust enhances both reliability and task performance, achieving state-of-the-art results on the PMC-VQA benchmark while enabling efficient zero-shot generalization to larger MLLMs.
If you wish to use open-source downstream task MLLMs, make sure you have at least 4 NVIDIA GPUs with adequate memory (the exact requirement depends on the scale of the LLM/MLLM you use).
We recommend downloading the open-source LLMs/MLLMs with `huggingface-cli` before you start (make sure you have obtained the relevant permissions/agreements to download the models from Hugging Face):

```shell
huggingface-cli login
huggingface-cli download {REPO_NAME} --local-dir {SAVE_FOLDER} --local-dir-use-symlinks False
```

For example, `{REPO_NAME}` can be `Qwen/Qwen2.5-1.5B-Instruct` and `{SAVE_FOLDER}` can be `/usr/local/data/Qwen2.5-1.5B-Instruct`. The downloaded model will be saved in the specified folder `{SAVE_FOLDER}`.
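With the example values above, the concrete commands would look like the following (the local directory is only an illustration; any path with sufficient space works):

```shell
# Log in once, then download Qwen2.5-1.5B-Instruct to an illustrative local folder
huggingface-cli login
huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct \
    --local-dir /usr/local/data/Qwen2.5-1.5B-Instruct \
    --local-dir-use-symlinks False
```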
It is recommended to use a virtual environment (e.g., `venv`) to avoid package conflicts. Here we assume you are using `venv` as your virtual environment; if you are using `conda`, please adjust the commands accordingly.
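For reference, a typical `venv` setup looks like this (the environment name `.venv` is arbitrary):

```shell
# Create and activate a virtual environment (the name .venv is arbitrary)
python3 -m venv .venv
source .venv/bin/activate
```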
```shell
git clone https://github.com/xingbpshen/prompt4trust.git
cd prompt4trust/
pip install -r requirements.txt
```
Create a `data/` folder under the project root `prompt4trust/`:

```shell
mkdir data/
```

Download the dataset from the PMC-VQA repository here. Put all files in the newly created `data/` folder.
Our training split can be generated with `dataset/gen_train.py` (after modifying the data path in the script) or can be downloaded here (coming soon).
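Assuming you have already pointed the data path inside the script at your `data/` folder, generating the split is a single command (whether the script takes any additional arguments is defined in `dataset/gen_train.py` itself):

```shell
# Run from the project root after editing the data path inside the script
python dataset/gen_train.py
```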
The config files are located at `config/`. You can modify the parameters according to your needs. The default config file `pmcvqa.yml` is for the PMC-VQA dataset.
Here are some important parameters you may want to modify (an illustrative sketch of these fields follows below):
- `resources.cache_dir`: This is where vLLM and other Python packages will be cached. Make sure you have enough space.
- `resources.policy_cuda`: This is a string of CUDA devices (e.g., `"3,4"` or `"3"`) used for the policy update/training. Make sure you have enough memory on these devices.
- `resources.action_cuda`: This is a string of CUDA devices used for TRL with vLLM serving to sample "actions" (in the reinforcement learning sense). Make sure you have enough memory on these devices.
- `resources.downstream_cuda`: This is a string of CUDA devices used for the downstream MLLM (to obtain the reward). Make sure you have enough memory on these devices.
- `model.policy`: The policy model name. You can use any repository name supported by Hugging Face or a path to a local model (e.g., `"Qwen/Qwen2.5-1.5B-Instruct"` or `"/usr/local/data/Qwen2.5-1.5B-Instruct"`).
- `model.downstream`: The downstream model name. You can use any repository name supported by Hugging Face or a path to a local model.
Please note that `resources.policy_cuda`, `resources.action_cuda`, and `resources.downstream_cuda` must not include any overlapping devices, to avoid CUDA initialization errors.
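As a rough sketch only (the authoritative keys and structure are those in `config/pmcvqa.yml`; the paths, device indices, and downstream model below are assumptions for illustration), the fields described above might be set like this:

```yaml
# Illustrative sketch; see config/pmcvqa.yml for the authoritative structure.
resources:
  cache_dir: /usr/local/data/cache   # assumed path; needs enough free space
  policy_cuda: "0,1"                 # GPUs for policy update/training
  action_cuda: "2"                   # GPUs for TRL with vLLM serving (action sampling)
  downstream_cuda: "3"               # GPUs for the downstream MLLM (reward)
model:
  policy: Qwen/Qwen2.5-1.5B-Instruct         # HF repo name or local path
  downstream: /usr/local/data/DownstreamMLLM # assumed local path to a downstream MLLM
```

Note that the three `*_cuda` strings in this sketch are disjoint, as required.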
To enable TRL with vLLM serving, two servers need to be started: one for the policy model (to sample actions) and one for the downstream MLLM (to calculate the reward). These servers are started automatically, so you do not need to do anything now.

By default, the policy model is trained with GRPO using TRL support. Run the following command to start training:
```shell
python main.py --config {DATASET}.yml --log_folder {LOG_FOLDER} --trial_name {TRIAL_NAME} --train --ni
```
Running the above command once will start:
- 2 detached subprocesses for vLLM, each corresponding to one of the servers. You can observe the GPU memory usage increasing in the terminal; use `nvidia-smi` to check the memory usage on the CUDA devices you specified in `resources.action_cuda` and `resources.downstream_cuda`.
- 1 foreground engine subprocess for TRL, which is responsible for training the policy model. You can observe the GPU memory usage increasing on the CUDA devices you specified in `resources.policy_cuda`.

Runtime-related logs will be saved in the `{LOG_FOLDER}/{TRIAL_NAME}/` folder.
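For example, to train on PMC-VQA with the default config (the log folder and trial name below are arbitrary placeholders you choose):

```shell
# {DATASET} is the config file name; the log folder and trial name are placeholders
python main.py --config pmcvqa.yml --log_folder logs --trial_name exp1 --train --ni
```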
Run the following command to evaluate the trained policy model:

```shell
python main.py --config {DATASET}.yml --log_folder {LOG_FOLDER} --trial_name {TRIAL_NAME} --test --ni
```
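For example, continuing the illustrative trial from the training section (presumably `{LOG_FOLDER}` and `{TRIAL_NAME}` should match the values used during training so that the corresponding trained policy and logs can be found):

```shell
# Placeholders match the illustrative training command above
python main.py --config pmcvqa.yml --log_folder logs --trial_name exp1 --test --ni
```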
This work was supported in part by the Natural Sciences and Engineering Research Council of Canada, in part by the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs Program, in part by the Mila - Quebec Artificial Intelligence Institute, in part by the compute resources provided by Mila (mila.quebec), in part by the Mila-Google Research Grant, in part by the Fonds de recherche du Québec, in part by the Canada First Research Excellence Fund, awarded to the Healthy Brains, Healthy Lives initiative at McGill University, and in part by the Department of Electrical and Computer Engineering at McGill University.
Please raise a GitHub issue or email us at [email protected] (with the email subject starting with "[Prompt4Trust]") if you have any questions or encounter any issues.