Skip to content

Commit 1609aad

Browse files
committed
GPU and batching support for captioning
1 parent 0ac68e7 commit 1609aad

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

fastdup/captions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import cv2
55

66

7-
def generate_labels(filenames, modelname='automatic', batch_size=8):
7+
def generate_labels(filenames, model_name='automatic', device=-1, batch_size=8):
88
'''
99
This function generates captions for a given set of images, and takes the following arguments:
1010
- filenames: the list of images passed to the function
@@ -14,7 +14,9 @@ def generate_labels(filenames, modelname='automatic', batch_size=8):
1414
- BLIP-2: 'blip2'
1515
- BLIP: 'blip'
1616
- batch_size: the size of image batches to caption (default: 8)
17+
    - device: the device to run inference on (default: -1 for CPU; set to 0 to use the first GPU)
1718
'''
19+
# use GPU if device is specified
1820

1921
# confirm necessary dependencies are installed, and import them
2022
try:
@@ -39,12 +41,11 @@ def generate_labels(filenames, modelname='automatic', batch_size=8):
3941
'blip': "Salesforce/blip-image-captioning-large"
4042
}
4143

42-
model = models[modelname]
44+
model = models[model_name]
4345

4446
# generate captions
4547
try:
46-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47-
captioner = pipeline("image-to-text", model=model, device=device, batch_size=batch_size)
48+
captioner = pipeline("image-to-text", model=model, device=device, batch_size=batch_size)
4849

4950
captions = []
5051
for image_path in tqdm(filenames):

fastdup/fastdup_controller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,7 @@ def _verify_fastdup_run_args(self, input_dir, work_dir, df_annot, subset, data_t
12601260
else:
12611261
assert False, f"Wrong data type {data_type}"
12621262

1263-
def caption(self, model_name='automatic', subset: list = None, vqa_prompt: str = None, kwargs=None) -> pd.DataFrame:
1263+
def caption(self, model_name='automatic', device=-1, batch_size: int = 8, subset: list = None, vqa_prompt: str = None, kwargs=None) -> pd.DataFrame:
12641264
if not self._fastdup_applied:
12651265
raise RuntimeError('Fastdup was not applied yet, call run() first')
12661266

@@ -1272,7 +1272,7 @@ def caption(self, model_name='automatic', subset: list = None, vqa_prompt: str =
12721272

12731273
if model_name in FD.CAPTION_MODEL_NAMES:
12741274
from fastdup.captions import generate_labels
1275-
df['caption'] = generate_labels(df['filename'], model_name)
1275+
df['caption'] = generate_labels(df['filename'], model_name, device, batch_size)
12761276
elif model_name == FD.VQA_MODEL1_NAME:
12771277
from fastdup.captions import generate_vqa_labels
12781278
df['caption'] = generate_vqa_labels(df['filename'], vqa_prompt, kwargs)

fastdup/galleries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def swap_dataframe(subdf, cols):
7474

7575

7676

77-
def find_label(get_label_func, df, in_col, out_col, vqa_prompt: str = None, kwargs=None):
77+
def find_label(get_label_func, df, in_col, out_col, vqa_prompt: str = None, device=-1, kwargs=None):
7878

7979

8080
if (get_label_func is not None):
@@ -87,7 +87,7 @@ def find_label(get_label_func, df, in_col, out_col, vqa_prompt: str = None, kwar
8787
df[out_col] = df['label']
8888
elif get_label_func in CAPTION_MODEL_NAMES:
8989
from fastdup.captions import generate_labels
90-
df[out_col] = generate_labels(df[in_col], get_label_func)
90+
df[out_col] = generate_labels(df[in_col], get_label_func, device)
9191
elif get_label_func == VQA_MODEL1_NAME:
9292
from fastdup.captions import generate_vqa_labels
9393
df[out_col] = generate_vqa_labels(df[in_col], vqa_prompt, kwargs)

0 commit comments

Comments
 (0)