import datetime
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange

from models.criteria import get_sim
from utils.basic import MetricLogger
from utils.dist import get_rank, get_world_size

logger = logging.getLogger(__name__)


def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
    """Encode all captions in batches; returns (num_text, Lt, d) features and attention masks."""
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_atts = []

    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)

        text_feat = model.encode_text(text_input)[0]
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)

    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    return text_feats, text_atts


def extract_vision_feats(data_loader, model, device, config):
    """Encode all visual inputs; offload features to CPU when config.eval_offload is set."""
    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    media_type = data_loader.dataset.medium
    for vis, _ in iterator:
        vis = vis.to(device, non_blocking=True)
        vis_feat, pooled_vis_feat = model.get_vis_enc_for_eval(vis, media_type)
        # if config.evaluation.eval_frame_ensemble == "concat":  # default
        #     image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
        vis_feat = vis_feat.unsqueeze(1)  # (bsz, 1, l, d)
        # else:
        #     assert config.video_input.num_frames == 1, "only support single-frame"
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        if config.eval_offload:
            # offload extracted features to CPU; they are moved back to `device` on demand later
            image_feats_all.append(vis_feat.cpu())
            pooled_image_feats_all.append(pooled_vis_feat.cpu())
        else:
            image_feats_all.append(vis_feat)
            pooled_image_feats_all.append(pooled_vis_feat)

    image_feats_all = torch.cat(image_feats_all, dim=0)
    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return image_feats_all, pooled_image_feats_all
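
# For orientation only: a minimal sketch of the similarity that the ITC scoring in
# `evaluation` is assumed to rely on. The real `get_sim` imported from models.criteria may
# differ (e.g., it may apply a learned temperature); `_get_sim_reference` is a hypothetical
# name used purely for illustration and is not called anywhere in this module.
def _get_sim_reference(pooled_vis_proj, pooled_cap_proj, temp=1.0):
    """Cosine similarity between projected visual and caption features.

    pooled_vis_proj: (num_vis, num_frm, d) or (num_vis, d); pooled_cap_proj: (num_txt, d).
    Returns (vis->txt, txt->vis) score matrices.
    """
    vis = torch.nn.functional.normalize(pooled_vis_proj, dim=-1)
    cap = torch.nn.functional.normalize(pooled_cap_proj, dim=-1)
    if vis.ndim == 3:
        # average per-frame similarities into one clip-level score (an assumption)
        sim_v2t = torch.einsum("mfd,nd->mfn", vis, cap).mean(1) / temp
    else:
        sim_v2t = vis @ cap.t() / temp
    return sim_v2t, sim_v2t.t()
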
@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
    with torch.cuda.amp.autocast(enabled=config.fp16):
        i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
            model, data_loader, tokenizer, device, config
        )
    score_pairs = [
        (prefix + "/", i2t_x, t2i_x),
        (prefix + "_emb/", i2t_emb, t2i_emb),
    ]
    res = dict()
    for name, i2t, t2i in score_pairs:
        if i2t is not None:
            txt2img_ids = data_loader.dataset.txt2vis
            img2txt_ids = data_loader.dataset.vis2txt
            res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
    return res


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"
    dtype = torch.half if config.fp16 else torch.float
    media_type = data_loader.dataset.medium
    logger.info(f"Start evaluation for {media_type}")
    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    max_txt_l = config.max_cap_len
    text_feats, text_atts = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )  # (bsz, Lt, d), (bsz, Lt)
    image_feats, pooled_image_feats = extract_vision_feats(
        data_loader, model, device, config
    )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
    logger.info("Finished feature extraction")

    logger.info("Computing ITC scores [dot-product]")
    _pooled_image_feats = (
        pooled_image_feats.to(device, non_blocking=True)
        if config.eval_offload
        else pooled_image_feats
    )
    i2t_scores, t2i_scores = get_sim(
        model.vis_proj(_pooled_image_feats), model.cap_proj(text_feats[:, 0])
    )
    logger.info("Computing ITC scores [dot-product], done!")

    num_images = len(data_loader.dataset.vis)
    i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    # computes only part of the scores on each GPU, gather at the end
    logger.info("Rerank dual-encoder results with cross-encoder...")
    num_tasks = get_world_size()
    rank = get_rank()
    # only uses the part associated with the raw eval set
    # compute image2text
    step = num_images // num_tasks + 1
    start = rank * step
    end = min(num_images, start + step)

    text_encoder = model.get_expert_encoder('vis_cap_grounding')

    iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
    logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

    # generate score for each clip, and aggregate all clip scores for a video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not False else image_feats[0].shape[1]
    # )
    # logger.info(
    #     f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={'concat'}"
    # )
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            # if config.deep_fusion:
            #     encoder_output = [
            #         feat[start + i, clip_idx].to(device, non_blocking=True)
            #         for feat in image_feats
            #     ]
            # else:
            encoder_output = (
                image_feats[start + i, clip_idx].to(device, non_blocking=True)
                if config.eval_offload
                else image_feats[start + i, clip_idx]
            )  # (#frm*Li, d)

            """ original
            encoder_output = encoder_output.repeat(k, 1, 1)  # (k=128, #frm*Li, d)
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[topk_idx],
                attention_mask=text_atts[topk_idx],
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )
            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new: score the top-k candidate captions in chunks of `bs` to bound memory
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []
            # if config.deep_fusion:
            #     encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
            #     encoder_att = [
            #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
            #             device, non_blocking=True
            #         )
            #         for feat in encoder_output
            #     ]
            # else:
            encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
            encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                device, non_blocking=True
            )
            for j in range(0, len(topk_idx), bs):
                cur_bs = len(topk_idx[j : j + bs])
                output = text_encoder(
                    encoder_embeds=text_feats[topk_idx[j : j + bs]],
                    attention_mask=text_atts[topk_idx[j : j + bs]],
                    # trim the repeated visual batch to match a (possibly smaller) last chunk
                    encoder_hidden_states=encoder_output[:cur_bs],
                    encoder_attention_mask=encoder_att[:cur_bs],
                    return_dict=True,
                )
                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)
            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]  # logit of the "match" class
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
        #         )

        i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
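
    # The text->vision pass below mirrors the vision->text rerank above: for each caption,
    # its top-k candidate clips (by dual-encoder score) are re-scored with the cross-encoder
    # in chunks of `bs`, with the caption features repeated along the batch dimension to
    # match the visual candidates.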
    # compute text2image
    num_text = len(data_loader.dataset.text)
    t2i_scores_x = torch.full((num_text, len(data_loader.dataset.vis)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    step = num_text // num_tasks + 1
    start = rank * step
    end = min(num_text, start + step)

    iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
    logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")

    # generate score for each clip, and aggregate all clip scores for a video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
    # )
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            """ old
            encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \
                if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx]
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[start+i].repeat(k, 1, 1),
                attention_mask=text_atts[start+i].repeat(k, 1),
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )
            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new: score the top-k candidate clips in chunks of `bs` to bound memory
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):
                # if config.deep_fusion:
                #     encoder_output = [
                #         feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
                #         for feat in image_feats
                #     ]
                #     encoder_att = [
                #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
                #             device, non_blocking=True
                #         )
                #         for feat in encoder_output
                #     ]
                # else:
                encoder_output = (
                    image_feats[topk_idx[j : j + bs], clip_idx].to(
                        device, non_blocking=True
                    )
                    if config.eval_offload
                    else image_feats[topk_idx[j : j + bs], clip_idx]
                )
                encoder_att = torch.ones(
                    encoder_output.size()[:-1], dtype=torch.long
                ).to(device, non_blocking=True)
                repeat_n = (
                    encoder_output.shape[0]
                    # if not config.deep_fusion
                    # else encoder_output[0].shape[0]
                )
                output = text_encoder(
                    encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
                    attention_mask=text_atts[start + i].repeat(repeat_n, 1),
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                    # mode="fusion",
                )
                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)
            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]  # logit of the "match" class
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
        #         )

        t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
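
    # Each rank has filled only its own row slice [start, end) of i2t_scores_x / t2i_scores_x;
    # every other entry still holds the -100.0 fill value. Summing across ranks therefore
    # shifts each row by a constant and keeps unscored entries below any cross-encoder score
    # (which sits far above -100 in practice), so the per-row ranking is preserved.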
    if config.distributed:
        # gather partial scores across GPUs
        dist.barrier()
        dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
        dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    return (
        i2t_scores_x.cpu().numpy(),
        t2i_scores_x.cpu().numpy(),
        i2t_scores.cpu().numpy(),
        i2t_scores.T.cpu().numpy(),
    )


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    # Vision -> Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        gt_txt_ids = img2txt[index]
        if isinstance(gt_txt_ids, int):
            ranks[index] = np.where(inds == gt_txt_ids)[0][0]
        else:  # list, used when each visual item has multiple GT captions
            rank = 1e20
            for i in gt_txt_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text -> Vision
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        gt_img_ids = txt2img[index]
        if isinstance(gt_img_ids, int):
            ranks[index] = np.where(inds == gt_img_ids)[0][0]
        else:  # list, used when each caption has multiple GT images
            rank = 1e20
            for i in gt_img_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "vis_r1": ir1,
        "vis_r5": ir5,
        "vis_r10": ir10,
        "vis_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    eval_result = {k: round(v, 2) for k, v in eval_result.items()}
    return eval_result
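

if __name__ == "__main__":
    # Minimal sanity check for itm_eval on synthetic scores, not part of the evaluation
    # pipeline. The 3x3 shapes and identity ground-truth mappings are illustrative
    # assumptions; with the diagonal boosted, every recall metric should come out at 100.0.
    rng = np.random.default_rng(0)
    sim = rng.normal(size=(3, 3)).astype(np.float32)
    np.fill_diagonal(sim, 10.0)  # make each ground-truth pair the top-ranked candidate
    txt2vis = {i: i for i in range(3)}  # caption i matches clip i
    vis2txt = {i: i for i in range(3)}  # clip i matches caption i
    print(itm_eval(sim, sim.T, txt2vis, vis2txt))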