Skip to content

Commit bca6470

Browse files
committed
lint(clip): add missing argument
1 parent b22f03d commit bca6470

File tree

1 file changed

+8
-6
lines changed
  • signwriting_evaluation/metrics

1 file changed

+8
-6
lines changed

signwriting_evaluation/metrics/clip.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def cache_name(self, clip_input: CLIPInput):
105105
return hashlib.md5(clip_input.tobytes()).hexdigest()
106106
return clip_input
107107

108-
def get_clip_features(self, inputs: list[CLIPInput]):
108+
def get_clip_features(self, inputs: list[CLIPInput], progress_bar=True):
109109
missing = [clip_input for clip_input in inputs if self.cache_name(clip_input) not in self.cached_texts]
110110

111111
if len(missing) > 0:
112-
pbar_disable = len(missing) <= self.batch_size
112+
pbar_disable = not progress_bar or len(missing) <= self.batch_size
113113
pbar = tqdm(total=len(inputs), initial=len(inputs) - len(missing),
114114
desc="Computing CLIP features", disable=pbar_disable)
115115

@@ -122,7 +122,8 @@ def get_clip_features(self, inputs: list[CLIPInput]):
122122

123123
pbar.close()
124124

125-
texts = tqdm(inputs, desc="Loading features cache", disable=len(inputs) <= self.batch_size)
125+
texts = tqdm(inputs, desc="Loading features cache",
126+
disable=not progress_bar or len(inputs) <= self.batch_size)
126127
cached_features = [self.cache[self.cache_name(text)].cpu() for text in texts]
127128
features = torch.stack(cached_features)
128129

@@ -131,9 +132,10 @@ def get_clip_features(self, inputs: list[CLIPInput]):
131132
def score(self, hypothesis: CLIPInput, reference: CLIPInput) -> float:
132133
return self.score_all([hypothesis], [reference])[0][0]
133134

134-
def score_all(self, hypotheses: list[CLIPInput], references: list[CLIPInput]) -> list[list[float]]:
135-
hyp_features = self.get_clip_features(hypotheses)
136-
ref_features = self.get_clip_features(references)
135+
def score_all(self, hypotheses: list[CLIPInput], references: list[CLIPInput],
136+
progress_bar=True) -> list[list[float]]:
137+
hyp_features = self.get_clip_features(hypotheses, progress_bar)
138+
ref_features = self.get_clip_features(references, progress_bar)
137139

138140
similarities = []
139141
for hyp_feature in hyp_features:

0 commit comments

Comments
 (0)