import os
import datetime
import wandb
import torch
import json
import numpy as np
from copy import deepcopy
from time import time
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.nn.utils.clip_grad import clip_grad_value_

from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts
from datasets.utils import get_datasets_media
from datasets.dataloader import MetaLoader
from utils.dist import is_main_process, get_rank, get_world_size
from utils.logger import setup_wandb, log_dict_to_wandb
from .retrieval_utils import evaluation_wrapper
import glog as logger


def run_epoch(
    model,
    train_dataloaders,
    # expert_tokenizer,
    # enc_dec_tokenizer,
    optimizer,
    epoch,
    global_step,
    visdial_step,
    avsd_step,
    nextqa_step,
    device,
    scheduler,
    scaler,
    config
):
    model.train()

    media_types = list(train_dataloaders.keys())
    log_freq = config['log_freq']

    # metric_logger = MetricLogger(delimiter=' ')
    # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}'))
    # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}"))

    loss_names = ['loss_' + k for k in config['loss_dict'].keys()]

    # for l in loss_names:
    #     for m in media_types:
    #         metric_logger.add_meter(
    #             f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}")
    #         )

    # header = '{} | Epoch = {}'.format(config['stage'], epoch)

    model_without_ddp = model
    if config['distributed']:
        model_without_ddp = model.module

    for k in train_dataloaders:
        train_dataloaders[k].sampler.set_epoch(epoch)

    # if len(train_dataloaders) == 1:
    #     train_dataloader = list(train_dataloaders.values())[0]
    # else:
    train_dataloader = MetaLoader(name2loader=train_dataloaders)

    log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Loss] tot = {:.4f} | gen = {:.4f} \n'
    log_text_template += '[Other] lr = {:.6f} | iter_time = {:.2f} | eta = {}\n'

    # iterator = metric_logger.log_every(train_dataloader, log_freq, header)

    local_step = 0
    # vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list
    for media_type, batch in train_dataloader:
        vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3]
        start = time()
        vis = vis.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            loss_dict = model(vis, caption, history, answer, media_type)
            loss = sum(loss_dict.values())
            loss = loss / config['accum_grad_every']

        scaler.scale(loss).backward()

        # Perform gradient clipping: unscale --> clip
        if config['clip_grad_value'] > 0:
            # scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), config['clip_grad_value'])

        if local_step % config.accum_grad_every == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        time_iter = time() - start
        eta = (len(train_dataloader) - local_step - 1) * time_iter
        eta = str(datetime.timedelta(seconds=eta))

        # log
        log_dict_visdial = {}
        log_dict_avsd = {}
        log_dict_nextqa = {}
        log_dict_rest = {}

        for loss_name in loss_names:
            value = loss_dict[loss_name]
            value = value if isinstance(value, float) else value.item()
            # metric_logger.update(**{f"{media_type}/{loss_name}": value})
            if media_type == 'visdial':
                log_dict_visdial[f"train/{media_type}/{loss_name}"] = value
            elif media_type == 'avsd':
                log_dict_avsd[f"train/{media_type}/{loss_name}"] = value
            elif media_type == 'nextqa':
                log_dict_nextqa[f"train/{media_type}/{loss_name}"] = value

        log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"]

        if is_main_process() and local_step % log_freq == 0 and local_step % config['accum_grad_every'] == 0:
            log_dict_rest['train/other/step'] = global_step
            if media_type == 'visdial':
                log_dict_visdial['train/visdial/step'] = visdial_step
                log_dict = log_dict_visdial
            elif media_type == 'avsd':
                log_dict_avsd['train/avsd/step'] = avsd_step
                log_dict = log_dict_avsd
            elif media_type == 'nextqa':
                log_dict_nextqa['train/nextqa/step'] = nextqa_step
                log_dict = log_dict_nextqa

            log_text = log_text_template.format(
                epoch, config.epochs - 1, local_step, len(train_dataloader), media_type,
                loss.item(), log_dict[f'train/{media_type}/loss_gen'],
                log_dict_rest['train/other/lr'], time_iter, eta
            )
            logger.info(log_text)

            if config['wandb_enabled']:
                wandb.log(log_dict_rest)
                wandb.log(log_dict)

        if media_type == 'visdial':
            visdial_step += 1
        elif media_type == 'avsd':
            avsd_step += 1
        elif media_type == 'nextqa':
            nextqa_step += 1

        local_step += 1
        global_step += 1

    return global_step, visdial_step, avsd_step, nextqa_step

    # if is_main_process() and local_step % config['log_model_outputs_every'] == 0 and config['log_model_outputs']:
    #     predictions = []
    #     labels = []
    #     probs = F.softmax(logits, dim=-1)
    #     preds = torch.topk(probs, 1)[1].squeeze(-1)
    #     preds = preds.tolist()
    #     lm_labels_list = label_ids['input_ids'].tolist()
    #     lm_labels_list = [[s for s in label if s != 1] for label in lm_labels_list]
    #     # reponses = ''
    #     # labels = ''
    #     model_pred_text = ''
    #     for pred, label in zip(preds, lm_labels_list):
    #         predictions.append('\n' + 'Pred: ' + tokenizer_enc_dec.decode(pred) + '\n')
    #         labels.append('\n' + 'GT: ' + tokenizer_enc_dec.decode(label) + '\n')
    #     if len(predictions) < 4:
    #         predictions = predictions[:4]
    #         labels = labels[:4]
    #     for label, pred in zip(labels, predictions):
    #         model_pred_text += label + pred
    #         model_pred_text += "---------------------"
    #     logger.info(model_pred_text)
    #     # output['reponses'] = reponses
    #     # output['gt'] = labels


def eval(model, val_dataloader, device, epoch, config):
    model.eval()

    log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
    # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
    # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
    log_text_template += '[Losses] gen = {:.4f} \n'

    # cum_loss_stc = 0
    # cum_loss_stm = 0
    # cum_loss_vcc = 0
    # cum_loss_vcm = 0
    # cum_loss_vhc = 0
    # cum_loss_vhm = 0
    # cum_loss_chc = 0
    # cum_loss_chm = 0
    # cum_loss_mlm = 0
    cum_loss_gen = 0
    cum_loss_tot = 0

    val_step = 0
    media_type = val_dataloader.dataset.medium

    if is_main_process():
        start_time = time()

    # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
    for batch in val_dataloader:
        vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3]
        vis = vis.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            with torch.no_grad():
                # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type)
                # loss_dict, _ = model(vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, media_type)
                loss_dict = model(vis, caption, history, answer, media_type)
                # loss_dict = model(vis, cap_ids, hist_ids, ques_ids, label_ids, media_type)
                loss = sum(loss_dict.values())

        # loss_stc = loss_dict['loss_stc']
        # loss_stm = loss_dict['loss_stm']
        # loss_vcc = loss_dict['loss_vcc']
        # loss_vcm = loss_dict['loss_vcm']
        # loss_vhc = loss_dict['loss_vhc']
        # loss_vhm = loss_dict['loss_vhm']
        # loss_chc = loss_dict['loss_chc']
        # loss_chm = loss_dict['loss_chm']
        # loss_mlm = loss_dict['loss_mlm']
        loss_gen = loss_dict['loss_gen']

        if config['distributed']:
            dist.all_reduce(loss, op=ReduceOp.AVG)
            # if config.loss_dict['stc'] != 0:
            #     dist.all_reduce(loss_stc, op=ReduceOp.AVG)
            # if config.loss_dict['stm'] != 0:
            #     dist.all_reduce(loss_stm, op=ReduceOp.AVG)
            # if config.loss_dict['vcc'] != 0:
            #     dist.all_reduce(loss_vcc, op=ReduceOp.AVG)
            # if config.loss_dict['vcm'] != 0:
            #     dist.all_reduce(loss_vcm, op=ReduceOp.AVG)
            # if config.loss_dict['vhc'] != 0:
            #     dist.all_reduce(loss_vhc, op=ReduceOp.AVG)
            # if config.loss_dict['vhm'] != 0:
            #     dist.all_reduce(loss_vhm, op=ReduceOp.AVG)
            # if config.loss_dict['chc'] != 0:
            #     dist.all_reduce(loss_chc, op=ReduceOp.AVG)
            # if config.loss_dict['chm'] != 0:
            #     dist.all_reduce(loss_chm, op=ReduceOp.AVG)
            # if config.loss_dict['mlm'] != 0:
            #     dist.all_reduce(loss_mlm, op=ReduceOp.AVG)
            if config.loss_dict['gen'] != 0:
                dist.all_reduce(loss_gen, op=ReduceOp.AVG)

        if is_main_process():
            cum_loss_tot += loss.item()
            # cum_loss_stc += loss_stc.item()
            # cum_loss_stm += loss_stm.item()
            # cum_loss_vcc += loss_vcc.item()
            # cum_loss_vcm += loss_vcm.item()
            # cum_loss_vhc += loss_vhc.item()
            # cum_loss_vhm += loss_vhm.item()
            # cum_loss_chc += loss_chc.item()
            # cum_loss_chm += loss_chm.item()
            # cum_loss_mlm += loss_mlm.item()
            cum_loss_gen += loss_gen.item()

            if val_step % config.log_freq == 0:
                # log_text_template = '\n' + '-' * 25 + '\n[Val Eoch{}][Iter. {}/{}][Media-type {}]\n'
                # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
                # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
                log_text = log_text_template.format(
                    epoch, val_step, len(val_dataloader), media_type,
                    # loss_vcc, loss_vcm, loss_stc, loss_stm, 0,
                    # loss_vhc, loss_vhm, loss_chc, loss_chm,
                    loss_gen
                )
                logger.info(log_text)
                # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format(
                #     epoch, val_step, len(val_dataloader), gen_loss, loss
                # ))

        val_step += 1

    if config['distributed']:
        dist.barrier()

    if is_main_process():
        duration = time() - start_time
        cum_loss_tot /= len(val_dataloader)
        # cum_loss_stc /= len(val_dataloader)
        # cum_loss_stm /= len(val_dataloader)
        # cum_loss_vcc /= len(val_dataloader)
        # cum_loss_vcm /= len(val_dataloader)
        # cum_loss_vhc /= len(val_dataloader)
        # cum_loss_vhm /= len(val_dataloader)
        # cum_loss_chc /= len(val_dataloader)
        # cum_loss_chm /= len(val_dataloader)
        # cum_loss_mlm /= len(val_dataloader)
        cum_loss_gen /= len(val_dataloader)

        logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format(
            datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot
        ))

        loss_dict = {
            # 'stc': cum_loss_stc,
            # 'stm': cum_loss_stm,
            # 'vcc': cum_loss_vcc,
            # 'vcm': cum_loss_vcm,
            # 'vhc': cum_loss_vhc,
            # 'vhm': cum_loss_vhm,
            # 'chc': cum_loss_chc,
            # 'chm': cum_loss_chm,
            # 'mlm': cum_loss_mlm,
            'gen': cum_loss_gen,
            'tot': cum_loss_tot
        }

    return loss_dict


def ft_avsd(
    model,
    model_without_ddp,
    train_dataloaders,
    val_dataloaders,
    optimizer,
    global_step,
    visdial_step,
    avsd_step,
    nextqa_step,
    scheduler,
    scaler,
    start_epoch,
    config
):
    if is_main_process() and config['wandb_enabled']:
        run = setup_wandb(config)

    setup_seed(config['seed'] + get_rank())
    # device = torch.device('cuda:{}'.format(config['gpu']))
    device = config.device

    # expert_tokenizer = model_without_ddp.expert_tokenizer
    # enc_dec_tokenizer = model_without_ddp.enc_dec_tokenizer

    if is_main_process() and config['wandb_enabled']:
        wandb.watch(model)

    best = float('inf')

    logger.info('[INFO] Start training...')
    start_time_all = time()

    for epoch in range(start_epoch, config['epochs']):
        if not config['evaluate']:
            if is_main_process():
                start_time_epoch = time()

            global_step, visdial_step, avsd_step, nextqa_step = run_epoch(
                model,
                train_dataloaders,
                # expert_tokenizer,
                # enc_dec_tokenizer,
                optimizer,
                epoch,
                global_step,
                visdial_step,
                avsd_step,
                nextqa_step,
                device,
                scheduler,
                scaler,
                config
            )

            if is_main_process():
                end_time_epoch = time()
                epoch_time = end_time_epoch - start_time_epoch
                epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
                logger.info(f'[INFO] Epoch took {epoch_time_str}')

        if not config['debugging']:
            with torch.cuda.amp.autocast(enabled=config['fp16']):
                val_res = {}
                for medium in val_dataloaders:
                    res = eval(
                        model,
                        val_dataloaders[medium],
                        # expert_tokenizer,
                        # enc_dec_tokenizer,
                        device,
                        epoch,
                        config
                    )
                    val_res[medium] = res

            if is_main_process():
                # Average across all datasets
                avg_val_res = average_dicts(val_res)

                # log to wandb
                if config.wandb_enabled:
                    for medium in val_res:
                        log_dict_val = {}
                        # log_dict_val[f'val/{medium}/step'] = epoch
                        for l in val_res[medium]:
                            log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l]
                        wandb.log(log_dict_val)
                    # for p, v in eval_res.items():
                    #     log_dict_to_wandb(v, step=global_step, prefix=p)

                if config.stop_key is not None and config.stop_key in avg_val_res:
                    cur_best = avg_val_res[config.stop_key]
                else:  # stop_key = None
                    cur_best = best - 1  # save the last as the best

                # Don't save vit and llm weights as they are frozen
                state_dict = model_without_ddp.state_dict()
                if config.freeze_vit:
                    state_dict = {k: v for k, v in state_dict.items() if 'visual_encoder' not in k}
                if config.freeze_llm:
                    state_dict = {k: v for k, v in state_dict.items() if 'llm' not in k}

                save_obj = {
                    "model": state_dict,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                    "visdial_step": visdial_step,
                    "avsd_step": avsd_step,
                    "nextqa_step": nextqa_step
                }
                torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth"))

                if not config.evaluate and cur_best < best:
                    torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth"))
                    # eval_file = "eval_res_best.json"
                    # eval_res.to_json(os.path.join(config.log_dir, eval_file))
                    best = cur_best

        if config.evaluate:
            break

        if config['distributed']:
            dist.barrier()

    total_time = time() - start_time_all
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'[INFO] Training took {total_time_str}')

    if is_main_process() and config['wandb_enabled']:
        run.finish()


def generate(model, dataloader, tag, config, gen_subset_num=None):
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device

    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))

    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, ans, vis_ids) in enumerate(dataloader):
            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                if config.use_moes:
                    if config.use_sep_spatial_temp_experts:
                        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)
                    else:
                        vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                    # construct the global input tensor --> use place holder for vis features
                    if config.use_sep_spatial_temp_experts:
                        moe_outputs = model.moe_forward(
                            vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask,
                            cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device
                        )
                        spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                        temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                    else:
                        moe_outputs = model.moe_forward_no_sep_spatial_temporal(
                            vis_embed, vis_mask, cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device
                        )
                        vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds'])

                    cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                    hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])
                else:
                    cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids))
                    hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids))
                    vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)
                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if config.use_moes:
                        if not config.drop_vis_features:
                            if config.use_sep_spatial_temp_experts:
                                if is_vid:
                                    inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                                else:
                                    inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                            else:
                                inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                                attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_mask, attention_mask], dim=1)

                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids]

            for vis_id, response in zip(vis_ids, response_batch):
                responses[vis_id] = response

            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    with open(config['anno_avsd_test_dstc_{}'.format(config['dstc'])], 'r') as f:
        test_data = json.load(f)
    test_dialogs = deepcopy(test_data['dialogs'])

    # Filter the predicted dialogs
    test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

    for i, dialog in enumerate(test_dialogs):
        vid_id = dialog['image_id']
        gen_response = responses[vid_id]
        round_num_to_answer = len(dialog['dialog']) - 1
        assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
        dialog['dialog'][round_num_to_answer]['answer'] = gen_response
        test_dialogs[i] = dialog

    # Log the file
    file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name

    output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump({'dialogs': test_dialogs}, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))

    # Switch back to training mode
    model.train()


def generate_visdial(model, dataloader, tag, config, gen_subset_num=None):
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device

    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))

    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, ans, vis_ids, d_rounds) in enumerate(dataloader):
            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                # construct the global input tensor --> use place holder for vis features
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                if config.use_moes:
                    if config.use_sep_spatial_temp_experts:
                        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)
                    else:
                        vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                    if config.use_sep_spatial_temp_experts:
                        moe_outputs = model.moe_forward(
                            vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask,
                            cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device
                        )
                        spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                        temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                    else:
                        moe_outputs = model.moe_forward_no_sep_spatial_temporal(
                            vis_embed, vis_mask, cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device
                        )
                        vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds'])

                    cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                    hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])
                else:
                    cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids))
                    hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids))
                    vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)
                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if config.use_moes:
                        if not config.drop_vis_features:
                            if config.use_sep_spatial_temp_experts:
                                if is_vid:
                                    inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                                else:
                                    inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                            else:
                                inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                                attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_mask, attention_mask], dim=1)

                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids]

            for vis_id, d_round, response in zip(vis_ids.tolist(), d_rounds.tolist(), response_batch):
                responses[str(vis_id) + '_' + str(d_round)] = response

            time_elapsed = time() - start_time
            print('Generating response {} / {} -- eta = {} '.format(
                counter + 1, len(dataloader), str(datetime.timedelta(seconds=time_elapsed * (len(dataloader) - counter)))
            ))

    # # Create a file with all responses
    # with open(config['anno_avsd_test_dstc_{}'.format(config['dstc'])], 'r') as f:
    #     test_data = json.load(f)
    # test_dialogs = deepcopy(test_data['dialogs'])
    # # Filter the predicted dialogs
    # test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))
    # for i, dialog in enumerate(test_dialogs):
    #     vid_id = dialog['image_id']
    #     gen_response = responses[vid_id]
    #     round_num_to_answer = len(dialog['dialog'])-1
    #     assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
    #     dialog['dialog'][round_num_to_answer]['answer'] = gen_response
    #     test_dialogs[i] = dialog

    # Log the file
    file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name

    output_path = os.path.join(config['output_dir_visdial'], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump(responses, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))

    # Switch back to training mode
    model.train()


def generate_nextqa(model, dataloader, tag, config, gen_subset_num=None):
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device

    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))

    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, _, vid_ids, qid) in enumerate(dataloader):
            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            vid_id = vid_ids[0]
            qid = qid[0]
            if vid_id not in responses:
                responses[vid_id] = {}

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)

                # construct the global input tensor --> use place holder for vis features
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                moe_outputs = model.moe_forward(
                    vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask,
                    cap_ids, cap_mask, hist_ids, hist_mask, is_vid, device
                )

                spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])

                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)
                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if is_vid:
                        inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)

                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response = model.tokenizer.decode(decoded_ids[0], skip_special_tokens=True)
            responses[vid_id][qid] = response
            # for vis_id, response in zip(vis_ids, response_batch):
            #     responses[vis_id] = response

            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    file_name = 'results_nextqa_beam_depth_{}'.format(config['beam_depth'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name

    output_path = os.path.join(config['output_dir_nextqa'], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump(responses, f, indent=4)
    print('Results logged to {}'.format(output_path))
    print(os.getcwd())

    # Switch back to training mode
    model.train()


def generate_enc_dec(model, dataloader, tag, config, gen_subset_num=None):
    model.eval()
    responses = {}
    tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device

    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))

    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap_ids, hist_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
            start_time = time()
            vis = vis.to(device, non_blocking=True)

            for k in cap_ids:
                if isinstance(cap_ids[k], torch.Tensor):
                    cap_ids[k] = cap_ids[k].to(device)
            for k in hist_ids:
                if isinstance(hist_ids[k], torch.Tensor):
                    hist_ids[k] = hist_ids[k].to(device)
            # for k in ques_ids:
            #     if isinstance(ques_ids[k], torch.Tensor):
            #         ques_ids[k] = ques_ids[k].to(device)
            for k in enc_dec_input_ids:
                if isinstance(enc_dec_input_ids[k], torch.Tensor):
                    enc_dec_input_ids[k] = enc_dec_input_ids[k].to(device)

            # response = beam_search_generation(
            #     model, vis, cap_ids, hist_ids, ques_ids, enc_dec_input_ids, tokenizer_enc_dec, config
            # )
            response = beam_search_generation(
                model, vis, cap_ids, hist_ids, enc_dec_input_ids, tokenizer_enc_dec, config
            )

            # Decode the response
            response = tokenizer_enc_dec.decode(response)
            responses[vid_id[0]] = response
            # all_graphs[vid] = graphs

            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    with open(config['anno_avsd_test_{}'.format(config['dstc'])], 'r') as f:
        test_data = json.load(f)
    test_dialogs = deepcopy(test_data['dialogs'])

    # Filter the predicted dialogs
    test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

    for i, dialog in enumerate(test_dialogs):
        vid_id = dialog['image_id']
        gen_response = responses[vid_id]
        round_num_to_answer = len(dialog['dialog']) - 1
        assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
        dialog['dialog'][round_num_to_answer]['answer'] = gen_response
        test_dialogs[i] = dialog

    # Log the file
    file_name = 'results_dstc{}_beam_depth_{}'.format(config['dstc'], config['beam_depth'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name

    output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump({'dialogs': test_dialogs}, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))

    # Switch back to training mode
    model.train()


def beam_search_generation_decoder_only(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config):
    # gen_ans = [bos_token]
    # Resolve the special token ids from the tokenizer (same convention as beam_search_generation below)
    if config['enc_dec_family'] == 'flan_t5':
        bos_token = tokenizer_enc_dec.pad_token_id
        eos_token = tokenizer_enc_dec.eos_token_id
    else:
        bos_token = tokenizer_enc_dec.bos_token_id
        eos_token = tokenizer_enc_dec.eos_token_id
    unk_token = tokenizer_enc_dec.unk_token_id

    # Each hypothesis is a tuple: (generated token ids, cumulative log-prob, decoder input state)
    hyplist = [([], 0.0, [])]
    best_state = None
    comp_hyplist = []

    # drop_caption = self.config['dstc'] == 10
    # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)
    encoder_outputs = None

    for i in range(config['max_generation_length']):
        new_hyplist = []
        argmin = 0
        for out, lp, st in hyplist:
            decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
            # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            if encoder_outputs is None:
                encoder_outputs = output.encoder_outputs

            logits = output['logits'][:, -1, :].squeeze()  # get the logits of the last token
            logp = F.log_softmax(logits, dim=0)
            lp_vec = logp.cpu().data.numpy() + lp

            if i >= config['min_generation_length']:
                new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp

            count = 1
            for o in np.argsort(lp_vec)[::-1]:  # reverse the order
                if o in [eos_token, unk_token]:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == config['beam_depth']:
                    if new_hyplist[argmin][1] < new_lp:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        break
                else:
                    new_st = deepcopy(st)
                    new_st.append(int(o))
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == config['beam_depth']:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                count += 1
        hyplist = new_hyplist

    if len(comp_hyplist) > 0:
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
        res = maxhyps[0][0]
        if res[0] == bos_token:
            res = res[1:]
        if res[-1] == eos_token:
            res = res[:-1]
        return res
    else:
        return []


# def beam_search_generation(model, vis, caption, history, ques, enc_dec_input, tokenizer_enc_dec, config):
def beam_search_generation(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config):
    if config['enc_dec_family'] == 'flan_t5':
        bos_token = tokenizer_enc_dec.pad_token_id
        eos_token = tokenizer_enc_dec.eos_token_id
    else:
        bos_token = tokenizer_enc_dec.bos_token_id
        eos_token = tokenizer_enc_dec.eos_token_id
    unk_token = tokenizer_enc_dec.unk_token_id

    # gen_ans = [bos_token]
    hyplist = [([], 0.0, [bos_token])]
    best_state = None
    comp_hyplist = []

    # drop_caption = self.config['dstc'] == 10
    # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)
    encoder_outputs = None

    for i in range(config['max_generation_length']):
        new_hyplist = []
        argmin = 0
        for out, lp, st in hyplist:
            decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
            # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            if encoder_outputs is None:
                encoder_outputs = output.encoder_outputs

            logits = output['logits'][:, -1, :].squeeze()  # get the logits of the last token
            logp = F.log_softmax(logits, dim=0)
            lp_vec = logp.cpu().data.numpy() + lp

            if i >= config['min_generation_length']:
                new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp

            count = 1
            for o in np.argsort(lp_vec)[::-1]:  # reverse the order
                if o in [eos_token, unk_token]:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == config['beam_depth']:
                    if new_hyplist[argmin][1] < new_lp:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        break
                else:
                    new_st = deepcopy(st)
                    new_st.append(int(o))
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == config['beam_depth']:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                count += 1
        hyplist = new_hyplist

    if len(comp_hyplist) > 0:
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
        res = maxhyps[0][0]
        if res[0] == bos_token:
            res = res[1:]
        if res[-1] == eos_token:
            res = res[:-1]
        return res
    else:
        return []