import datetime
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange

from models.criteria import get_sim
from utils.basic import MetricLogger
from utils.dist import get_rank, get_world_size

logger = logging.getLogger(__name__)


def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
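    """Encode all captions with the text encoder, in chunks of 256.

    Returns:
        text_feats: (num_text, Lt, d) token-level text features.
        text_atts: (num_text, Lt) attention masks.
    """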
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_atts = []

    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)

        text_feat = model.encode_text(text_input)[0]
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)

    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    return text_feats, text_atts


def extract_vision_feats(data_loader, model, device, config):
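    """Encode every image/video in the eval set.

    When config.eval_offload is set, features are kept on CPU to save GPU
    memory and moved back to the device on demand during reranking.

    Returns:
        image_feats_all: (num_vis, 1, #frm*Li, d) token-level features.
        pooled_image_feats_all: (num_vis, #frm, d) pooled features.
    """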
    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    media_type = data_loader.dataset.medium
    for vis, _ in iterator:
        vis = vis.to(device, non_blocking=True)
        vis_feat, pooled_vis_feat = model.get_vis_enc_for_eval(vis, media_type)
        # if config.evaluation.eval_frame_ensemble == "concat":  # default
        #     image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
        vis_feat = vis_feat.unsqueeze(1)  # (bsz, 1, l, d)
        # else:
        #     assert config.video_input.num_frames == 1, "only support single-frame"
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        if config.eval_offload:
            # offload features to CPU to save GPU memory during evaluation
            image_feats_all.append(vis_feat.cpu())
            pooled_image_feats_all.append(pooled_vis_feat.cpu())
        else:
            image_feats_all.append(vis_feat)
            pooled_image_feats_all.append(pooled_vis_feat)

    image_feats_all = torch.cat(image_feats_all, dim=0)

    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return image_feats_all, pooled_image_feats_all


@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
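    """Run retrieval evaluation under autocast and compute metrics.

    Returns a dict with one entry per score type: f"{prefix}/" for the
    cross-encoder reranked scores and f"{prefix}_emb/" for the dual-encoder
    similarity scores.
    """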
    with torch.cuda.amp.autocast(enabled=config.fp16):
        i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
            model, data_loader, tokenizer, device, config
        )
    score_pairs = [
        (prefix + "/", i2t_x, t2i_x),
        (prefix + "_emb/", i2t_emb, t2i_emb),
    ]
    res = dict()
    for name, i2t, t2i in score_pairs:
        if i2t is not None:
            txt2img_ids = data_loader.dataset.txt2vis
            img2txt_ids = data_loader.dataset.vis2txt
            res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
    return res


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
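    """Compute image/video-text retrieval score matrices.

    Dual-encoder (ITC) similarities are computed for all pairs on every GPU;
    each GPU then reranks the top-k candidates of its own shard of queries
    with the cross-encoder and the VCM head, and the sharded score matrices
    are summed across GPUs with all_reduce.

    Returns:
        i2t_scores_x, t2i_scores_x: reranked score matrices (numpy).
        i2t_scores, i2t_scores.T: dual-encoder score matrices (numpy).
    """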
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"
    dtype = torch.half if config.fp16 else torch.float
    media_type = data_loader.dataset.medium
    logger.info(f"Start evaluation for {media_type}")

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    max_txt_l = config.max_cap_len

    text_feats, text_atts = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )  # (bsz, Lt, d), (bsz, Lt)

    image_feats, pooled_image_feats = extract_vision_feats(
        data_loader, model, device, config
    )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
    logger.info("Finished feature extraction")
    logger.info("Computing ITC scores [dot-product]")
    _pooled_image_feats = (
        pooled_image_feats.to(device, non_blocking=True)
        if config.eval_offload
        else pooled_image_feats
    )
    i2t_scores, t2i_scores = get_sim(
        model.vis_proj(_pooled_image_feats), model.cap_proj(text_feats[:, 0])
    )
    logger.info("Computing ITC scores [dot-product], done!")

    num_images = len(data_loader.dataset.vis)
    i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    # compute only part of the scores on each GPU, gather at the end
    logger.info("Rerank dual-encoder results with cross-encoder...")
    num_tasks = get_world_size()
    rank = get_rank()
    # only uses the part associated with the raw eval set
    # compute image2text #
    step = num_images // num_tasks + 1
    start = rank * step
    end = min(num_images, start + step)

    text_encoder = model.get_expert_encoder("vis_cap_grounding")

    iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
    logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

    # generate a score for each clip, then aggregate the clip scores per video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not False else image_feats[0].shape[1]
    # )

    # logger.info(
    #     f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={'concat'}"
    # )

    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            # if config.deep_fusion:
            #     encoder_output = [
            #         feat[start + i, clip_idx].to(device, non_blocking=True)
            #         for feat in image_feats
            #     ]

            # else:
            encoder_output = (
                image_feats[start + i, clip_idx].to(device, non_blocking=True)
                if config.eval_offload
                else image_feats[start + i, clip_idx]
            )  # (#frm*Li, d)

            """ original
            encoder_output = encoder_output.repeat(k, 1, 1)  # (k=128, #frm*Li, d)
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[topk_idx],
                attention_mask=text_atts[topk_idx],
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )

            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new: rerank the top-k candidate texts in mini-batches to bound memory
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []

            # if config.deep_fusion:
            #     encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
            #     encoder_att = [
            #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
            #             device, non_blocking=True
            #         )
            #         for feat in encoder_output
            #     ]
            # else:
            encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
            encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                device, non_blocking=True
            )

            for j in range(0, len(topk_idx), bs):
                cur_bs = len(topk_idx[j : j + bs])  # the last batch may be < bs
                output = text_encoder(
                    encoder_embeds=text_feats[topk_idx[j : j + bs]],
                    attention_mask=text_atts[topk_idx[j : j + bs]],
                    encoder_hidden_states=encoder_output[:cur_bs],
                    encoder_attention_mask=encoder_att[:cur_bs],
                    return_dict=True,
                )
                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)
            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
        #         )

        i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

    # compute text2image #
    num_text = len(data_loader.dataset.text)
    t2i_scores_x = torch.full((num_text, len(data_loader.dataset.vis)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    step = num_text // num_tasks + 1
    start = rank * step
    end = min(num_text, start + step)

    iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
    logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
    # generate a score for each clip, then aggregate the clip scores per video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
    # )

    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)
        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            """old
            encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \
                if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx]
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[start+i].repeat(k, 1, 1),
                attention_mask=text_atts[start+i].repeat(k, 1),
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )

            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new: rerank the top-k candidate videos in mini-batches to bound memory
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):
                # if config.deep_fusion:
                #     encoder_output = [
                #         feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
                #         for feat in image_feats
                #     ]
                #     encoder_att = [
                #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
                #             device, non_blocking=True
                #         )
                #         for feat in encoder_output
                #     ]
                # else:
                encoder_output = (
                    image_feats[topk_idx[j : j + bs], clip_idx].to(
                        device, non_blocking=True
                    )
                    if config.eval_offload
                    else image_feats[topk_idx[j : j + bs], clip_idx]
                )
                encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                    device, non_blocking=True
                )

                repeat_n = (
                    encoder_output.shape[0]
                    # if not config.deep_fusion
                    # else encoder_output[0].shape[0]
                )
                output = text_encoder(
                    encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
                    attention_mask=text_atts[start + i].repeat(repeat_n, 1),
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                    # mode="fusion",
                )

                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)

            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
        #         )

        t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

    if config.distributed:
        # gather results across GPUs
        dist.barrier()
        dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
        dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    return (
        i2t_scores_x.cpu().numpy(),
        t2i_scores_x.cpu().numpy(),
        i2t_scores.cpu().numpy(),
        i2t_scores.T.cpu().numpy(),
    )


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
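    """Compute R@1/5/10 and mean-recall metrics from retrieval score matrices.

    txt2img / img2txt map each query index to its ground-truth index, or to a
    list of indices when a query has multiple ground truths.
    """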
    # Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # Score
        gt_txt_ids = img2txt[index]
        if isinstance(gt_txt_ids, int):
            ranks[index] = np.where(inds == gt_txt_ids)[0][0]
        else:
            rank = 1e20
            for i in gt_txt_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text->Images
    ranks = np.zeros(scores_t2i.shape[0])

    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        gt_img_ids = txt2img[index]
        if isinstance(gt_img_ids, int):
            ranks[index] = np.where(inds == gt_img_ids)[0][0]
        else:  # list, used in the case each caption has multiple GT images
            # Score
            rank = 1e20
            for i in gt_img_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "vis_r1": ir1,
        "vis_r5": ir5,
        "vis_r10": ir10,
        "vis_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    eval_result = {k: round(v, 2) for k, v in eval_result.items()}
    return eval_result