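"""Training, validation, and response-generation loops for fine-tuning on
AVSD, optionally mixed with VisDial and NExT-QA."""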

import os
import datetime
import json
from copy import deepcopy
from time import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ReduceOp
from torch.nn.utils.clip_grad import clip_grad_value_

import wandb
import glog as logger

from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts
from utils.dist import is_main_process, get_rank, get_world_size
from utils.logger import setup_wandb, log_dict_to_wandb
from datasets.utils import get_datasets_media
from datasets.dataloader import MetaLoader
from .retrieval_utils import evaluation_wrapper

def run_epoch(
    model,
    train_dataloaders,
    # expert_tokenizer,
    # enc_dec_tokenizer,
    optimizer,
    epoch,
    global_step,
    visdial_step,
    avsd_step,
    nextqa_step,
    device,
    scheduler,
    scaler,
    config
):
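    """Train for one epoch over the media-type-mixed `MetaLoader`.

    Returns the updated (global_step, visdial_step, avsd_step, nextqa_step)
    counters so per-dataset logging can resume seamlessly in the next epoch.
    """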
    model.train()
    media_types = list(train_dataloaders.keys())

    log_freq = config['log_freq']
    # metric_logger = MetricLogger(delimiter=' ')
    # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}'))
    # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}"))

    loss_names = ['loss_' + k for k in config['loss_dict'].keys()]
    # for l in loss_names:
    #     for m in media_types:
    #         metric_logger.add_meter(
    #             f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}")
    #         )

    # header = '{} | Epoch = {}'.format(config['stage'], epoch)

    model_without_ddp = model
    if config['distributed']:
        model_without_ddp = model.module
        for k in train_dataloaders:
            train_dataloaders[k].sampler.set_epoch(epoch)

    # if len(train_dataloaders) == 1:
    #     train_dataloader = list(train_dataloaders.values())[0]
    # else:
    train_dataloader = MetaLoader(name2loader=train_dataloaders)

    log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Loss] tot = {:.4f} | gen = {:.4f} \n'
    log_text_template += '[Other] lr = {:.6f} | iter_time = {:.2f} | eta = {}\n'

    # iterator = metric_logger.log_every(train_dataloader, log_freq, header)
    local_step = 0
    # each batch: vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list
    for media_type, batch in train_dataloader:
        vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3]

        start = time()
        vis = vis.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            loss_dict = model(vis, caption, history, answer, media_type)
            loss = sum(loss_dict.values())
            loss = loss / config['accum_grad_every']

        scaler.scale(loss).backward()

        if local_step % config.accum_grad_every == 0:
            # Perform gradient clipping on the unscaled gradients before stepping
            if config['clip_grad_value'] > 0:
                scaler.unscale_(optimizer)
                clip_grad_value_(model.parameters(), config['clip_grad_value'])
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        time_iter = time() - start
        eta = (len(train_dataloader) - local_step - 1) * time_iter
        eta = str(datetime.timedelta(seconds=eta))
        # log
        log_dict_visdial = {}
        log_dict_avsd = {}
        log_dict_nextqa = {}

        log_dict_rest = {}

        for loss_name in loss_names:
            value = loss_dict[loss_name]
            value = value if isinstance(value, float) else value.item()
            # metric_logger.update(**{f"{media_type}/{loss_name}": value})
            if media_type == 'visdial':
                log_dict_visdial[f"train/{media_type}/{loss_name}"] = value
            elif media_type == 'avsd':
                log_dict_avsd[f"train/{media_type}/{loss_name}"] = value
            elif media_type == 'nextqa':
                log_dict_nextqa[f"train/{media_type}/{loss_name}"] = value

        log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"]

        if is_main_process() and local_step % log_freq == 0 and local_step % config['accum_grad_every'] == 0:
            log_dict_rest['train/other/step'] = global_step

            if media_type == 'visdial':
                log_dict_visdial['train/visdial/step'] = visdial_step
                log_dict = log_dict_visdial
            elif media_type == 'avsd':
                log_dict_avsd['train/avsd/step'] = avsd_step
                log_dict = log_dict_avsd
            elif media_type == 'nextqa':
                log_dict_nextqa['train/nextqa/step'] = nextqa_step
                log_dict = log_dict_nextqa

            log_text = log_text_template.format(
                epoch, config.epochs - 1, local_step, len(train_dataloader), media_type, loss.item(),
                log_dict[f'train/{media_type}/loss_gen'], log_dict_rest['train/other/lr'],
                time_iter, eta
            )
            logger.info(log_text)

            if config['wandb_enabled']:
                wandb.log(log_dict_rest)
                wandb.log(log_dict)

        if media_type == 'visdial':
            visdial_step += 1
        elif media_type == 'avsd':
            avsd_step += 1
        elif media_type == 'nextqa':
            nextqa_step += 1

        local_step += 1
        global_step += 1

    return global_step, visdial_step, avsd_step, nextqa_step

def eval(model, val_dataloader, device, epoch, config):
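    """Compute validation losses on one dataloader. Losses are averaged across
    ranks (all_reduce) and over batches, and returned as a dict (the cumulative
    values are only populated on the main process).
    """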
    model.eval()

    log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
    # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
    # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
    log_text_template += '[Losses] gen = {:.4f} \n'

    # cum_loss_stc = 0
    # cum_loss_stm = 0
    # cum_loss_vcc = 0
    # cum_loss_vcm = 0
    # cum_loss_vhc = 0
    # cum_loss_vhm = 0
    # cum_loss_chc = 0
    # cum_loss_chm = 0
    # cum_loss_mlm = 0
    cum_loss_gen = 0
    cum_loss_tot = 0
    val_step = 0
    media_type = val_dataloader.dataset.medium
    if is_main_process():
        start_time = time()

    # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
    for batch in val_dataloader:
        vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3]
        vis = vis.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            with torch.no_grad():
                # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type)
                # loss_dict, _ = model(vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, media_type)
                loss_dict = model(vis, caption, history, answer, media_type)

        # loss_dict = model(vis, cap_ids, hist_ids, ques_ids, label_ids, media_type)
        loss = sum(loss_dict.values())
        # loss_stc = loss_dict['loss_stc']
        # loss_stm = loss_dict['loss_stm']
        # loss_vcc = loss_dict['loss_vcc']
        # loss_vcm = loss_dict['loss_vcm']
        # loss_vhc = loss_dict['loss_vhc']
        # loss_vhm = loss_dict['loss_vhm']
        # loss_chc = loss_dict['loss_chc']
        # loss_chm = loss_dict['loss_chm']
        # loss_mlm = loss_dict['loss_mlm']
        loss_gen = loss_dict['loss_gen']

        if config['distributed']:
            dist.all_reduce(loss, op=ReduceOp.AVG)
            # if config.loss_dict['stc'] != 0:
            #     dist.all_reduce(loss_stc, op=ReduceOp.AVG)
            # if config.loss_dict['stm'] != 0:
            #     dist.all_reduce(loss_stm, op=ReduceOp.AVG)
            # if config.loss_dict['vcc'] != 0:
            #     dist.all_reduce(loss_vcc, op=ReduceOp.AVG)
            # if config.loss_dict['vcm'] != 0:
            #     dist.all_reduce(loss_vcm, op=ReduceOp.AVG)
            # if config.loss_dict['vhc'] != 0:
            #     dist.all_reduce(loss_vhc, op=ReduceOp.AVG)
            # if config.loss_dict['vhm'] != 0:
            #     dist.all_reduce(loss_vhm, op=ReduceOp.AVG)
            # if config.loss_dict['chc'] != 0:
            #     dist.all_reduce(loss_chc, op=ReduceOp.AVG)
            # if config.loss_dict['chm'] != 0:
            #     dist.all_reduce(loss_chm, op=ReduceOp.AVG)
            # if config.loss_dict['mlm'] != 0:
            #     dist.all_reduce(loss_mlm, op=ReduceOp.AVG)
            if config.loss_dict['gen'] != 0:
                dist.all_reduce(loss_gen, op=ReduceOp.AVG)
        if is_main_process():
            cum_loss_tot += loss.item()
            # cum_loss_stc += loss_stc.item()
            # cum_loss_stm += loss_stm.item()
            # cum_loss_vcc += loss_vcc.item()
            # cum_loss_vcm += loss_vcm.item()
            # cum_loss_vhc += loss_vhc.item()
            # cum_loss_vhm += loss_vhm.item()
            # cum_loss_chc += loss_chc.item()
            # cum_loss_chm += loss_chm.item()
            # cum_loss_mlm += loss_mlm.item()
            cum_loss_gen += loss_gen.item()

            if val_step % config.log_freq == 0:
                # log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
                # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
                # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
                log_text = log_text_template.format(
                    epoch, val_step, len(val_dataloader), media_type,
                    # loss_vcc, loss_vcm, loss_stc, loss_stm, 0,
                    # loss_vhc, loss_vhm, loss_chc, loss_chm,
                    loss_gen
                )
                logger.info(log_text)
                # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format(
                #     epoch, val_step, len(val_dataloader), gen_loss, loss
                # ))
        val_step += 1

    if config['distributed']:
        dist.barrier()
    if is_main_process():
        duration = time() - start_time

        cum_loss_tot /= len(val_dataloader)
        # cum_loss_stc /= len(val_dataloader)
        # cum_loss_stm /= len(val_dataloader)
        # cum_loss_vcc /= len(val_dataloader)
        # cum_loss_vcm /= len(val_dataloader)
        # cum_loss_vhc /= len(val_dataloader)
        # cum_loss_vhm /= len(val_dataloader)
        # cum_loss_chc /= len(val_dataloader)
        # cum_loss_chm /= len(val_dataloader)
        # cum_loss_mlm /= len(val_dataloader)
        cum_loss_gen /= len(val_dataloader)

        logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format(
            datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot
        ))

    loss_dict = {
        # 'stc': cum_loss_stc,
        # 'stm': cum_loss_stm,
        # 'vcc': cum_loss_vcc,
        # 'vcm': cum_loss_vcm,
        # 'vhc': cum_loss_vhc,
        # 'vhm': cum_loss_vhm,
        # 'chc': cum_loss_chc,
        # 'chm': cum_loss_chm,
        # 'mlm': cum_loss_mlm,
        'gen': cum_loss_gen,
        'tot': cum_loss_tot
    }
    return loss_dict

def ft_avsd(
    model,
    model_without_ddp,
    train_dataloaders,
    val_dataloaders,
    optimizer,
    global_step,
    visdial_step,
    avsd_step,
    nextqa_step,
    scheduler,
    scaler,
    start_epoch,
    config
):
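    """Fine-tune on AVSD (optionally mixed with VisDial / NExT-QA): run epochs,
    validate on every val dataloader, and checkpoint after each epoch, tracking
    the best model via `config.stop_key` on the averaged validation losses.
    """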
    if is_main_process() and config['wandb_enabled']:
        run = setup_wandb(config)
    setup_seed(config['seed'] + get_rank())
    # device = torch.device('cuda:{}'.format(config['gpu']))
    device = config.device

    # expert_tokenizer = model_without_ddp.expert_tokenizer
    # enc_dec_tokenizer = model_without_ddp.enc_dec_tokenizer

    if is_main_process() and config['wandb_enabled']:
        wandb.watch(model)

    best = float('inf')

    logger.info('[INFO] Start training...')
    start_time_all = time()
    for epoch in range(start_epoch, config['epochs']):
        if not config['evaluate']:
            if is_main_process():
                start_time_epoch = time()

            global_step, visdial_step, avsd_step, nextqa_step = run_epoch(
                model,
                train_dataloaders,
                # expert_tokenizer,
                # enc_dec_tokenizer,
                optimizer,
                epoch,
                global_step,
                visdial_step,
                avsd_step,
                nextqa_step,
                device,
                scheduler,
                scaler,
                config
            )
            if is_main_process():
                end_time_epoch = time()
                epoch_time = end_time_epoch - start_time_epoch
                epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
                logger.info(f'[INFO] Epoch took {epoch_time_str}')

        if not config['debugging']:
            with torch.cuda.amp.autocast(enabled=config['fp16']):
                val_res = {}

                for medium in val_dataloaders:
                    res = eval(
                        model,
                        val_dataloaders[medium],
                        # expert_tokenizer,
                        # enc_dec_tokenizer,
                        device,
                        epoch,
                        config
                    )
                    val_res[medium] = res
            if is_main_process():
                # Average across all datasets
                avg_val_res = average_dicts(val_res)
                # log to wandb
                if config.wandb_enabled:
                    for medium in val_res:
                        log_dict_val = {}
                        # log_dict_val[f'val/{medium}/step'] = epoch
                        for l in val_res[medium]:
                            log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l]
                        wandb.log(log_dict_val)
                # for p, v in eval_res.items():
                #     log_dict_to_wandb(v, step=global_step, prefix=p)
                if config.stop_key is not None and config.stop_key in avg_val_res:
                    cur_best = avg_val_res[config.stop_key]
                else:  # stop_key = None
                    cur_best = best - 1  # save the last checkpoint as the best

                # Don't save ViT and LLM weights as they are frozen
                state_dict = model_without_ddp.state_dict()
                if config.freeze_vit:
                    state_dict = {k: v for k, v in state_dict.items() if 'visual_encoder' not in k}

                if config.freeze_llm:
                    state_dict = {k: v for k, v in state_dict.items() if 'llm' not in k}

                save_obj = {
                    "model": state_dict,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                    "visdial_step": visdial_step,
                    "avsd_step": avsd_step,
                    "nextqa_step": nextqa_step
                }
                torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth"))

                if not config.evaluate and cur_best < best:
                    torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth"))
                    # eval_file = "eval_res_best.json"
                    # eval_res.to_json(os.path.join(config.log_dir, eval_file))
                    best = cur_best

        if config.evaluate:
            break
        if config['distributed']:
            dist.barrier()

    total_time = time() - start_time_all
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'[INFO] Training took {total_time_str}')

    if is_main_process() and config['wandb_enabled']:
        run.finish()

def generate(model, dataloader, tag, config, gen_subset_num=None):
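    """Generate AVSD test responses with the HF `generate` API and write them in
    DSTC submission format (fills in the '__UNDISCLOSED__' answer of the last
    round of each dialog).
    """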
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device
    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))
    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, ans, vis_ids) in enumerate(dataloader):

            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                if config.use_moes:
                    if config.use_sep_spatial_temp_experts:
                        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)
                    else:
                        vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                    # construct the global input tensor --> use placeholders for vis features
                    if config.use_sep_spatial_temp_experts:
                        moe_outputs = model.moe_forward(
                            vis_embed_spatial, vis_spatial_mask,
                            vis_embed_temporal, vis_temporal_mask,
                            cap_ids, cap_mask,
                            hist_ids, hist_mask,
                            is_vid, device
                        )
                        spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                        temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                    else:
                        moe_outputs = model.moe_forward_no_sep_spatial_temporal(
                            vis_embed, vis_mask,
                            cap_ids, cap_mask,
                            hist_ids, hist_mask,
                            is_vid, device
                        )
                        vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds'])

                    cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                    hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])
                else:
                    cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids))
                    hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids))
                    vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
                # NOTE: this decoder-only branch assumes MoE features with separate
                # spatial/temporal experts (spatial_embeds/temporal_embeds computed above)
                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)

                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
                    if config.use_moes:
                        if not config.drop_vis_features:
                            if config.use_sep_spatial_temp_experts:
                                if is_vid:
                                    inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                                else:
                                    inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                            else:
                                inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                                attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,  # pass the mask built above so padded positions are ignored
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids]

            for vis_id, response in zip(vis_ids, response_batch):
                responses[vis_id] = response

            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    with open(config['anno_avsd_test_dstc_{}'.format(config['dstc'])], 'r') as f:
        test_data = json.load(f)
    test_dialogs = deepcopy(test_data['dialogs'])
    # Filter the predicted dialogs
    test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

    for i, dialog in enumerate(test_dialogs):
        vid_id = dialog['image_id']
        gen_response = responses[vid_id]
        round_num_to_answer = len(dialog['dialog']) - 1
        assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
        dialog['dialog'][round_num_to_answer]['answer'] = gen_response
        test_dialogs[i] = dialog

    # Log the file
    file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name
    output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump({'dialogs': test_dialogs}, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))
    # Switch back to training mode
    model.train()

def generate_visdial(model, dataloader, tag, config, gen_subset_num=None):
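    """Generate VisDial responses and dump them as a {"<image_id>_<round>": response} JSON."""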
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device
    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))
    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, ans, vis_ids, d_rounds) in enumerate(dataloader):

            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                # construct the global input tensor --> use placeholders for vis features
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                if config.use_moes:
                    if config.use_sep_spatial_temp_experts:
                        vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)
                    else:
                        vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)

                    if config.use_sep_spatial_temp_experts:
                        moe_outputs = model.moe_forward(
                            vis_embed_spatial, vis_spatial_mask,
                            vis_embed_temporal, vis_temporal_mask,
                            cap_ids, cap_mask,
                            hist_ids, hist_mask,
                            is_vid, device
                        )
                        spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                        temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                    else:
                        moe_outputs = model.moe_forward_no_sep_spatial_temporal(
                            vis_embed, vis_mask,
                            cap_ids, cap_mask,
                            hist_ids, hist_mask,
                            is_vid, device
                        )
                        vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds'])

                    cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                    hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])
                else:
                    cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids))
                    hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids))
                    vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid)
                # NOTE: this decoder-only branch assumes MoE features with separate
                # spatial/temporal experts (spatial_embeds/temporal_embeds computed above)
                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)

                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
                    if config.use_moes:
                        if not config.drop_vis_features:
                            if config.use_sep_spatial_temp_experts:
                                if is_vid:
                                    inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                                else:
                                    inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                                    attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                            else:
                                inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                                attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,  # pass the mask built above so padded positions are ignored
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids]

            for vis_id, d_round, response in zip(vis_ids.tolist(), d_rounds.tolist(), response_batch):
                responses[str(vis_id) + '_' + str(d_round)] = response

            time_elapsed = time() - start_time
            print('Generating response {} / {} -- eta = {}'.format(
                counter + 1, len(dataloader),
                str(datetime.timedelta(seconds=time_elapsed * (len(dataloader) - counter)))
            ))

    # Log the file (unlike AVSD, responses are dumped directly -- no test dialog file to fill in)
    file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name
    output_path = os.path.join(config['output_dir_visdial'], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump(responses, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))
    # Switch back to training mode
    model.train()

def generate_nextqa(model, dataloader, tag, config, gen_subset_num=None):
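    """Generate NExT-QA answers one question at a time (assumes batch size 1)
    and dump them as a {video_id: {question_id: answer}} JSON.
    """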
    model.eval()
    responses = {}
    # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device
    # Generate the response for each question
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))
    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap, hist, _, vid_ids, qid) in enumerate(dataloader):

            start_time = time()
            vis = vis.to(device, non_blocking=True)
            is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa']

            vid_id = vid_ids[0]
            qid = qid[0]
            if vid_id not in responses:
                responses[vid_id] = {}

            # First get the visual features depending on the media type
            with torch.cuda.amp.autocast(enabled=config.fp16):
                vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid)

                # construct the global input tensor --> use placeholders for vis features
                cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None)
                hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None)

                moe_outputs = model.moe_forward(
                    vis_embed_spatial, vis_spatial_mask,
                    vis_embed_temporal, vis_temporal_mask,
                    cap_ids, cap_mask,
                    hist_ids, hist_mask,
                    is_vid, device
                )
                spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds'])
                temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None
                cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds'])
                hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds'])

                if config.llm_family in ['llama', 'mistral']:
                    bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id
                    bos_embeds = model.text_embedding(bos)
                    bos_mask = cap_mask[:, :1]

                    inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device)
                    if is_vid:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1)

                else:
                    inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device)

                    if is_vid:
                        inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1)
                    else:
                        inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1)
                        attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1)
                decoded_ids = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,  # pass the mask built above so padded positions are ignored
                    do_sample=False,
                    top_p=config.top_p,
                    temperature=config.temperature,
                    num_beams=config.beam_depth,
                    length_penalty=config.length_penalty,
                    max_length=config.max_generation_length,
                    pad_token_id=model.tokenizer.pad_token_id,
                    eos_token_id=model.tokenizer.eos_token_id,
                    # use_cache=True
                )

            response = model.tokenizer.decode(decoded_ids[0], skip_special_tokens=True)
            responses[vid_id][qid] = response

            # for vis_id, response in zip(vis_ids, response_batch):
            #     responses[vis_id] = response

            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    file_name = 'results_nextqa_beam_depth_{}'.format(config['beam_depth'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name
    output_path = os.path.join(config['output_dir_nextqa'], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump(responses, f, indent=4)
    print('Results logged to {}'.format(output_path))
    # Switch back to training mode
    model.train()

def generate_enc_dec(model, dataloader, tag, config, gen_subset_num=None):
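    """Generate AVSD test responses with the custom beam search of the
    encoder-decoder variant and write them in DSTC submission format.
    """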
    model.eval()
    responses = {}
    tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec
    device = next(model.parameters()).device  # Assumes all model parameters are on the same device
    # Generate the response for each round
    logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader)))
    with torch.no_grad():
        # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):
        for counter, (vis, cap_ids, hist_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader):

            start_time = time()
            vis = vis.to(device, non_blocking=True)

            # Move all tokenized inputs to the model's device
            for k in cap_ids:
                if isinstance(cap_ids[k], torch.Tensor):
                    cap_ids[k] = cap_ids[k].to(device)

            for k in hist_ids:
                if isinstance(hist_ids[k], torch.Tensor):
                    hist_ids[k] = hist_ids[k].to(device)

            # for k in ques_ids:
            #     if isinstance(ques_ids[k], torch.Tensor):
            #         ques_ids[k] = ques_ids[k].to(device)

            for k in enc_dec_input_ids:
                if isinstance(enc_dec_input_ids[k], torch.Tensor):
                    enc_dec_input_ids[k] = enc_dec_input_ids[k].to(device)

            # response = beam_search_generation(
            #     model, vis, cap_ids, hist_ids, ques_ids, enc_dec_input_ids, tokenizer_enc_dec, config
            # )
            response = beam_search_generation(
                model, vis, cap_ids, hist_ids, enc_dec_input_ids, tokenizer_enc_dec, config
            )

            # Decode the response
            response = tokenizer_enc_dec.decode(response)
            responses[vid_id[0]] = response
            # all_graphs[vid] = graphs
            time_elapsed = int(time() - start_time)
            print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed))

    # Create a file with all responses
    with open(config['anno_avsd_test_{}'.format(config['dstc'])], 'r') as f:
        test_data = json.load(f)
    test_dialogs = deepcopy(test_data['dialogs'])
    # Filter the predicted dialogs
    test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

    for i, dialog in enumerate(test_dialogs):
        vid_id = dialog['image_id']
        gen_response = responses[vid_id]
        round_num_to_answer = len(dialog['dialog']) - 1
        assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
        dialog['dialog'][round_num_to_answer]['answer'] = gen_response
        test_dialogs[i] = dialog

    # Log the file
    file_name = 'results_dstc{}_beam_depth_{}'.format(config['dstc'], config['beam_depth'])
    if gen_subset_num is not None:
        file_name += f'-part_{gen_subset_num}'
    file_name = f'{tag}_' + file_name
    output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json')
    with open(output_path, 'w') as f:
        json.dump({'dialogs': test_dialogs}, f, indent=4)
    logger.info('Results logged to {}'.format(output_path))
    # Switch back to training mode
    model.train()

def beam_search_generation_decoder_only(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config):
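    """Manual beam search for decoder-only generation.

    Mirrors `beam_search_generation` below, but starts from an empty decoder
    state instead of seeding it with the bos token.
    """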
    # Special-token ids (same convention as beam_search_generation below)
    if config['enc_dec_family'] == 'flan_t5':
        bos_token = tokenizer_enc_dec.pad_token_id
    else:
        bos_token = tokenizer_enc_dec.bos_token_id
    eos_token = tokenizer_enc_dec.eos_token_id
    unk_token = tokenizer_enc_dec.unk_token_id

    # gen_ans = [bos_token]
    hyplist = [([], 0.0, [])]
    best_state = None
    comp_hyplist = []

    # drop_caption = self.config['dstc'] == 10
    # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)

    encoder_outputs = None

    for i in range(config['max_generation_length']):
        new_hyplist = []
        argmin = 0
        for out, lp, st in hyplist:
            decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

            # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')

            if encoder_outputs is None:
                encoder_outputs = output.encoder_outputs

            logits = output['logits'][:, -1, :].squeeze()  # get the logits of the last token
            logp = F.log_softmax(logits, dim=0)
            lp_vec = logp.cpu().data.numpy() + lp
            if i >= config['min_generation_length']:
                new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp
            count = 1
            for o in np.argsort(lp_vec)[::-1]:  # iterate tokens by decreasing log-prob
                if o in [eos_token, unk_token]:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == config['beam_depth']:
                    if new_hyplist[argmin][1] < new_lp:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        break
                else:
                    new_st = deepcopy(st)
                    new_st.append(int(o))
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == config['beam_depth']:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                count += 1
        hyplist = new_hyplist

    if len(comp_hyplist) > 0:
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
        res = maxhyps[0][0]
        if res and res[0] == bos_token:
            res = res[1:]
        if res and res[-1] == eos_token:
            res = res[:-1]
        return res
    else:
        return []

def beam_search_generation(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config):
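    """Manual beam search for the encoder-decoder models.

    Returns the highest-scoring token id sequence (without bos/eos tokens).
    """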
    if config['enc_dec_family'] == 'flan_t5':
        bos_token = tokenizer_enc_dec.pad_token_id
        eos_token = tokenizer_enc_dec.eos_token_id
    else:
        bos_token = tokenizer_enc_dec.bos_token_id
        eos_token = tokenizer_enc_dec.eos_token_id

    unk_token = tokenizer_enc_dec.unk_token_id

    # gen_ans = [bos_token]
    hyplist = [([], 0.0, [bos_token])]
    best_state = None
    comp_hyplist = []

    # drop_caption = self.config['dstc'] == 10
    # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)

    encoder_outputs = None

    for i in range(config['max_generation_length']):
        new_hyplist = []
        argmin = 0
        for out, lp, st in hyplist:
            decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

            # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')
            output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd')

            if encoder_outputs is None:
                encoder_outputs = output.encoder_outputs

            logits = output['logits'][:, -1, :].squeeze()  # get the logits of the last token
            logp = F.log_softmax(logits, dim=0)
            lp_vec = logp.cpu().data.numpy() + lp
            if i >= config['min_generation_length']:
                new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp
            count = 1
            for o in np.argsort(lp_vec)[::-1]:  # iterate tokens by decreasing log-prob
                if o in [eos_token, unk_token]:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == config['beam_depth']:
                    if new_hyplist[argmin][1] < new_lp:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        break
                else:
                    new_st = deepcopy(st)
                    new_st.append(int(o))
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == config['beam_depth']:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                count += 1
        hyplist = new_hyplist

    if len(comp_hyplist) > 0:
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
        res = maxhyps[0][0]
        if res and res[0] == bos_token:
            res = res[1:]
        if res and res[-1] == eos_token:
            res = res[:-1]
        return res
    else:
        return []