Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/*.pt
/*.json
/*.txt
*.pyc
17 changes: 10 additions & 7 deletions RET_CLIP/eval/zeroshot_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from tqdm import tqdm

import torch
import sys
sys.path.append(".")

from RET_CLIP.clip.model import convert_weights, CLIP
from RET_CLIP.clip import tokenize
Expand Down Expand Up @@ -109,8 +111,10 @@ def zero_shot_classifier(model, classnames, templates, args):
zeroshot_weights = []
for classname in tqdm(classnames):
texts = [_preprocess_text(template(classname)) for template in templates] # format with class
texts = tokenize(texts, context_length=args.context_length).to(args.gpu) # tokenize
class_embeddings = model(None, texts)
text_ids = tokenize(texts, context_length=args.context_length).to(args.gpu) # tokenize
# model(texts) returns a tuple of (text, text_left, text_right).
# Since we provide the input image as img_l, we choose text_left as the class embeddings.
class_embeddings = model(None, None, text_ids)[1]
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
Expand All @@ -136,7 +140,7 @@ def run(model, classifier, dataloader, args):
total_targets.append(target)

# predict
image_features = model(images, None)
image_features = model(images, None, None)
image_features /= image_features.norm(dim=-1, keepdim=True)
logits = (100.0 * image_features @ classifier).softmax(dim=-1)
total_logits.append(logits)
Expand Down Expand Up @@ -215,13 +219,12 @@ def run(model, classifier, dataloader, args):
# Map model to be loaded to specified single gpu.
loc = "cuda:{}".format(args.gpu)
checkpoint = torch.load(args.resume, map_location='cpu')
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
sd = checkpoint
if next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
model.load_state_dict(sd)
print(
f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
f"=> loaded checkpoint '{args.resume}'"
)

# Compute ensembled class embeddings
Expand Down Expand Up @@ -257,7 +260,7 @@ def json_prec_dump(data, prec=6):
json.loads(json.dumps(data), parse_float=lambda x: round(float(x), prec))
)

print(logits.size())
#print(logits.size())
output_dict = {
"model_name": "CN-CLIP-" + args.vision_model,
"dataset_name": args.dataset,
Expand Down
2 changes: 1 addition & 1 deletion RET_CLIP/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch.nn.functional as F

from RET_CLIP.clip.model import convert_state_dict
from eval_RFMiD import eval_multiLabelCls_ViT, eval_multiLabelCls_RN50
from .eval_RFMiD import eval_multiLabelCls_ViT, eval_multiLabelCls_RN50


def is_master(args):
Expand Down