@@ -105,11 +105,11 @@ def cache_name(self, clip_input: CLIPInput):
105105 return hashlib .md5 (clip_input .tobytes ()).hexdigest ()
106106 return clip_input
107107
108- def get_clip_features (self , inputs : list [CLIPInput ]):
108+ def get_clip_features (self , inputs : list [CLIPInput ], progress_bar = True ):
109109 missing = [clip_input for clip_input in inputs if self .cache_name (clip_input ) not in self .cached_texts ]
110110
111111 if len (missing ) > 0 :
112- pbar_disable = len (missing ) <= self .batch_size
112+ pbar_disable = not progress_bar or len (missing ) <= self .batch_size
113113 pbar = tqdm (total = len (inputs ), initial = len (inputs ) - len (missing ),
114114 desc = "Computing CLIP features" , disable = pbar_disable )
115115
@@ -122,7 +122,8 @@ def get_clip_features(self, inputs: list[CLIPInput]):
122122
123123 pbar .close ()
124124
125- texts = tqdm (inputs , desc = "Loading features cache" , disable = len (inputs ) <= self .batch_size )
125+ texts = tqdm (inputs , desc = "Loading features cache" ,
126+ disable = not progress_bar or len (inputs ) <= self .batch_size )
126127 cached_features = [self .cache [self .cache_name (text )].cpu () for text in texts ]
127128 features = torch .stack (cached_features )
128129
@@ -131,9 +132,10 @@ def get_clip_features(self, inputs: list[CLIPInput]):
131132 def score (self , hypothesis : CLIPInput , reference : CLIPInput ) -> float :
132133 return self .score_all ([hypothesis ], [reference ])[0 ][0 ]
133134
134- def score_all (self , hypotheses : list [CLIPInput ], references : list [CLIPInput ]) -> list [list [float ]]:
135- hyp_features = self .get_clip_features (hypotheses )
136- ref_features = self .get_clip_features (references )
135+ def score_all (self , hypotheses : list [CLIPInput ], references : list [CLIPInput ],
136+ progress_bar = True ) -> list [list [float ]]:
137+ hyp_features = self .get_clip_features (hypotheses , progress_bar )
138+ ref_features = self .get_clip_features (references , progress_bar )
137139
138140 similarities = []
139141 for hyp_feature in hyp_features :
0 commit comments