Make code public
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions
488 runners/runner.py Normal file
@@ -0,0 +1,488 @@
import wandb
import os
import os.path as osp
import json
from collections import deque, OrderedDict
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log

import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils import clip_grad_value_


class Runner:
    def __init__(self, config):
        self.config = config
        if 'rank' in config:
            self.gpu_rank = config['rank']
        else:
            self.gpu_rank = 0

        self.epoch_idx = 0
        self.min_gen_val_loss = float('inf')
        self.best_epoch_idx = 0

        if self.config["max_ckpt_to_keep"] > 0:
            self.checkpoint_queue = deque([], maxlen=config["max_ckpt_to_keep"])
            self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])

        self.setup_wandb()

    def setup_wandb(self):
        if self.gpu_rank == 0:
            print("[INFO] Setting up wandb logging on rank {}".format(0))
            run = wandb.init(
                project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
        else:
            run = None
        self.run = run

    def forward(self, batch, eval=False):
        raise NotImplementedError

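    # A minimal sketch of the contract the task-specific runners implement
    # (the exact model call is an assumption; see runner_avsd.py and
    # runner_nextqa.py for the real versions):
    #
    #     def forward(self, batch):
    #         out = self.model(**batch)
    #         return {'losses': {'tot_loss': out.loss}}
    #
    # train()/val() only rely on output['losses'] holding scalar tensors
    # (plus 'responses'/'gt' strings when print_output is set).
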
    def train(self, dataset, dataset_eval):
        batch_size = self.config['batch_size']
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
        else:
            sampler = None

        data_loader = tud.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.config['training'] and not self.config['parallel'],
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler
        )

        start_epoch_idx = self.epoch_idx
        num_iter_epoch = self.config['num_iter_per_epoch']
        if self.config['display']:
            log.info(f'{num_iter_epoch} iter per epoch.')

        num_epochs = self.config['num_epochs']

        # Perform validation before training
        if self.config['eval_first']:
            _ = self.val(dataset_eval)

        for epoch_idx in range(start_epoch_idx, num_epochs):
            if self.config['parallel'] and self.config['dp_type'] != 'dp':
                sampler.set_epoch(epoch_idx)
            self.epoch_idx = epoch_idx

            if self.config['display']:
                log.info(f'starting epoch {epoch_idx}')
                log.info('training')

            self.model.train()

            num_batch = 0
            next_logging_pct = .1
            start_time = time.time()
            self.optimizer.zero_grad()

            for batch in data_loader:
                num_batch += 1
                pct = num_batch / num_iter_epoch * 100
                iter_now = num_iter_epoch * epoch_idx + num_batch

                output = self.forward(batch)

                losses = output['losses']

                # optimizer step
                losses['tot_loss'] /= self.config['batch_multiply']
                # debug
                if self.config['debugging']:
                    log.info('try backward')

                losses['tot_loss'].backward()
                if self.config['clip_grad_value'] > 0:
                    clip_grad_value_(self.model.parameters(), self.config['clip_grad_value'])
                if self.config['debugging']:
                    log.info('backward done')

                if iter_now % self.config['batch_multiply'] == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                self.scheduler.step()

                # display
                if pct >= next_logging_pct:
                    if self.config['display']:
                        loss_to_print = ''
                        for key in losses:
                            if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                                loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                        print(
                            f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter: {num_batch}/{len(data_loader)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                        )
                        if self.config['print_output']:
                            print(10 * '-' + 'responses' + 10 * '-')
                            print(output['responses'])
                            print(10 * '-' + 'gt' + 10 * '-')
                            print(output['gt'])

                    next_logging_pct += self.config["next_logging_pct"]

                if self.config['debugging']:
                    break

                lr_bart, lr_other = self.scheduler.get_lr()[0], self.scheduler.get_lr()[-1]

                elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
                elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
                gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
                if self.run:
                    self.run.log(
                        {
                            f"Train/{gen_key}": losses[gen_key].item(),
                            f"Train/{elbo_global_key}": losses[elbo_global_key].item(),
                            f"Train/{elbo_local_key}": losses[elbo_local_key].item(),
                            "Train/total_loss": losses['tot_loss'].item(),
                        },
                        step=iter_now
                    )

                    self.run.log(
                        {"Train/lr_bart": lr_bart, "Train/lr_other": lr_other},
                        step=iter_now
                    )
                del losses
                del output

            if self.config['display']:
                log.info(
                    f'100%,\ttime:\t{time.time() - start_time:.2f}'
                )
            if not self.config['overfit'] and self.run:
                self.save_ckpt()

            if not self.config['skip_eval']:
                iter_now = num_iter_epoch * (epoch_idx + 1)
                val_losses = self.val(dataset_eval)

                if self.config['display']:
                    log.info('#' * 100)
                    for k in val_losses:
                        log.info('Average val {} (epoch {}) = {}'.format(k, self.epoch_idx, val_losses[k]))
                    log.info('#' * 100)

                gen_val_loss = val_losses[gen_key]

                if gen_val_loss < self.min_gen_val_loss:
                    self.min_gen_val_loss = gen_val_loss
                    self.best_epoch_idx = epoch_idx
                    # Log the best model w.r.t. the validation data
                    if self.run and self.config['save_ckpt']:
                        self.save_ckpt_best()

                if self.run:
                    self.run.log(
                        {
                            f"Val/{gen_key}": val_losses[gen_key],
                            f"Val/{elbo_global_key}": val_losses[elbo_global_key],
                            f"Val/{elbo_local_key}": val_losses[elbo_local_key],
                            "Val/total_loss": val_losses['tot_loss'],
                            "Val/min_gen_loss": self.min_gen_val_loss
                        },
                        step=iter_now
                    )

            if self.config['parallel']:
                if self.config['dp_type'] == 'dp':
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    dist.barrier()
                    torch.cuda.empty_cache()

            if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
                if self.config['display']:
                    log.info('Stop for reaching stop_epochs.')
                break
        if self.config['display']:
            log.info(f'Best validation loss was reached at epoch {self.best_epoch_idx}.')

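    # The batch_multiply logic in train() is plain gradient accumulation:
    # every loss is scaled by 1 / batch_multiply, gradients from consecutive
    # batches add up, and the optimizer only steps on every batch_multiply-th
    # iteration. A standalone sketch of the same pattern (accum stands for
    # batch_multiply):
    #
    #     for it, batch in enumerate(loader, start=1):
    #         (compute_loss(batch) / accum).backward()
    #         if it % accum == 0:
    #             optimizer.step()
    #             optimizer.zero_grad()
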
    def val(self, dataset):
        total_loss_val = 0.0
        total_gen_loss_val = 0.0
        total_elbo_global_val = 0.0
        total_elbo_local_val = 0.0
        num_batch_val = 0
        next_logging_pct_val = 0.05

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])

        # Prepare the dataloader
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_val = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
            sampler_val.set_epoch(self.epoch_idx)
        else:
            sampler_val = None

        data_loader_val = tud.DataLoader(
            dataset=dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_val
        )

        if self.config['parallel'] and self.config['dp_type'] == 'dp':
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / self.config['batch_size']))
        else:
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / (self.config['batch_size'] * self.config['num_gpus'])))

        self.model.eval()

        if self.gpu_rank == 0:
            start_time = time.time()

        for batch in data_loader_val:
            num_batch_val += 1
            pct = num_batch_val / num_iter_per_epoch_val * 100

            with torch.no_grad():
                output = self.forward(batch)

            losses = output['losses']

            losses['tot_loss'] /= self.config['batch_multiply']
            losses[elbo_global_key] /= self.config['batch_multiply']
            losses[elbo_local_key] /= self.config['batch_multiply']
            losses[gen_key] /= self.config['batch_multiply']

            total_loss_val += losses['tot_loss'].item()
            total_gen_loss_val += losses[gen_key].item()
            total_elbo_global_val += losses[elbo_global_key].item()
            total_elbo_local_val += losses[elbo_local_key].item()

            # display
            if pct >= next_logging_pct_val:
                if self.config['display']:
                    loss_to_print = ''
                    for key in losses:
                        if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                            loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                    print(
                        f'[{int(pct)}%][Validating][Iter: {num_batch_val}/{num_iter_per_epoch_val}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                    )

                next_logging_pct_val += self.config["next_logging_pct"]

        loss_val = total_loss_val / num_batch_val
        gen_loss_val = total_gen_loss_val / num_batch_val
        elbo_global_val = total_elbo_global_val / num_batch_val
        elbo_local_val = total_elbo_local_val / num_batch_val

        losses_val = {
            'tot_loss': loss_val,
            elbo_global_key: elbo_global_val,
            elbo_local_key: elbo_local_val,
            gen_key: gen_loss_val
        }
        self.model.train()
        return losses_val

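    # Note: in DDP mode each rank averages only its own DistributedSampler
    # shard in val(); the returned losses are per-rank values and are not
    # all-reduced before rank 0 logs them.
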
    def save_eval_results(self, split, epoch_idx, metrics_results):

        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        with open(metrics_filename, 'w') as f:
            json.dump(metrics_results, f)
        log.info(f'Results of metrics saved to {metrics_filename}')

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.metrics_queue) == self.metrics_queue.maxlen:
                todel = self.metrics_queue.popleft()
                os.remove(todel)
            self.metrics_queue.append(metrics_filename)

        if epoch_idx == 'best':
            self.copy_best_predictions(split)

    def copy_best_results(self, split, epoch_idx):
        to_print = 'Copy '

        if not self.config['skip_saving_ckpt']:
            ckpt_path = osp.join(self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
            best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
            shutil.copyfile(ckpt_path, best_ckpt_path)
            to_print += best_ckpt_path + ' '

        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
        shutil.copyfile(metrics_filename, best_metric_filename)
        to_print += best_metric_filename + ' '

        log.info(to_print)

    def set_ckpt(self, ckpt_dict):
        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model

        model_state_dict = model.state_dict()

        former_dict = {k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(former_dict))
        assert len(former_dict.keys()) > 0

        model_state_dict.update(former_dict)

        model.load_state_dict(model_state_dict)
        if self.config['display']:
            log.info('loaded model')
        del model_state_dict, former_dict

        # if not self.config['uses_new_optimizer']:
        if not self.config['generating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
            if not self.config['restarts']:
                self.epoch_idx = ckpt_dict['epoch_idx'] + 1

            if not self.config['resets_min_val_loss']:
                self.min_gen_val_loss = ckpt_dict['min_gen_val_loss']

            self.optimizer.load_state_dict(ckpt_dict['optimizer'])
            if self.config['display']:
                log.info('loaded optimizer')
            if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])

        del ckpt_dict

        torch.cuda.empty_cache()

    def save_ckpt(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        if self.config['skip_saving_ckpt']:
            return ckpt_path
        # On torch > 1.4, fall back to the legacy (non-zipfile) serialization,
        # presumably so the checkpoints stay loadable with older torch versions.
        torch_version = float(torch.__version__[:3])
        if torch_version - 1.4 > 1e-3:
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
        del ckpt
        if not self.config['parallel']:
            torch.cuda.empty_cache()

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                os.remove(todel)
            self.checkpoint_queue.append(ckpt_path)

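    # Checkpoint rotation above leans on deque(maxlen=N): once the queue is
    # full, the oldest path is popped and deleted from disk before the newest
    # one is appended. Standalone sketch of the idea:
    #
    #     q = deque([], maxlen=2)
    #     for path in ['epoch_0.ckpt', 'epoch_1.ckpt', 'epoch_2.ckpt']:
    #         if len(q) == q.maxlen:
    #             os.remove(q.popleft())  # drops 'epoch_0.ckpt' on the 3rd pass
    #         q.append(path)
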
    def save_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        torch.save(ckpt, f=ckpt_path)
        del ckpt
        return ckpt_path

    def load_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        if not osp.exists(ckpt_path):
            ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) == 0:
                if self.config['display']:
                    log.info(f'No .ckpt found in {self.config["log_dir"]}')
                return
            sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
            ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

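    # The map_location dict passed to torch.load above remaps storages saved
    # from 'cuda:0' onto this process's own device ('cuda:<rank>'), so a
    # checkpoint written by rank 0 can be restored on every rank.
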
    def get_ckpt(self):
        ckpt = {
            'epoch_idx': self.epoch_idx,
            'min_gen_val_loss': self.min_gen_val_loss,
            'seed': self.config['random_seed'],
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        # Unwrap the DDP/DP wrapper only when it is actually there (mirrors set_ckpt).
        if self.config['parallel']:
            ckpt['model_state_dict'] = self.model.module.state_dict()
        else:
            ckpt['model_state_dict'] = self.model.state_dict()
        return ckpt

    def load_ckpt(self, ckpt_path=None):
        if not ckpt_path:
            if self.config['generating']:  # or self.config['start_ckpt_for_generating']:
                ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
            else:
                ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
                if len(ckpt_paths) == 0:
                    if self.config['display']:
                        log.info(f'No .ckpt found in {self.config["log_dir"]}')
                    return
                sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
                ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        epoch_name = osp.split(ckpt_path)[1].split('.')[0]
        if re.search(r"(\d+)", epoch_name):
            self.checkpoint_queue.append(ckpt_path)
            metrics_filename = osp.join(self.config['log_dir'], f'metrics_{epoch_name}.json')
            if osp.exists(metrics_filename):
                self.metrics_queue.append(metrics_filename)

        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def match_model_key(self, pretrained_dict, model_dict):
        matched_dict = dict()
        not_found = []
        for key in pretrained_dict:
            if key in model_dict:
                matched_key = key
            elif key.startswith('encoder.') and key[8:] in model_dict:
                matched_key = key[8:]
            elif key.startswith('module.') and key[7:] in model_dict:
                matched_key = key[7:]
            elif 'encoder.' + key in model_dict:
                matched_key = 'encoder.' + key
            elif 'module.' + key in model_dict:
                matched_key = 'module.' + key
            else:
                not_found.append(key)
                continue
            matched_dict[matched_key] = pretrained_dict[key]
        print("Keys from pretrained_dict that were not found in model_dict:\n", not_found)
        return matched_dict
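
# A hypothetical call site for match_model_key when warm-starting from a
# checkpoint saved under a different wrapper prefix ('module.' / 'encoder.'):
#
#     pretrained = torch.load('pretrained.ckpt', map_location='cpu')['model_state_dict']
#     matched = runner.match_model_key(pretrained, runner.model.state_dict())
#     runner.model.load_state_dict(matched, strict=False)
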
337 runners/runner_avsd.py Normal file
@@ -0,0 +1,337 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.avsd_bart import AVSDBart

from custom_datasets.avsd import build_input_from_segments
from time import time


class AVSDRunner(Runner):
    def __init__(self, config, tokenizer, vocab_size):
        super(AVSDRunner, self).__init__(config)
        bart_config = BartConfig.from_json_file(self.config['bart_config'])

        self.model = AVSDBart.from_pretrained(
            'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)

        # Resize the embedding to match the vocab with additional special toks
        # This takes care of resizing weights of related parts of the network
        # pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        # print(pytorch_total_params)

        if vocab_size != bart_config.vocab_size:
            self.model.resize_token_embeddings(vocab_size)

        self.model.to(self.config['device'])
        if not self.config['generating']:
            self.optimizer, self.scheduler = init_optim(self.model, self.config)
        self.tokenizer = tokenizer

    def forward(self, batch):

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].cuda()

        ########################################################
        input_ids = batch['input_ids']
        video_place_holder_ids = batch['video_place_holder_ids']
        i3d_rgb = batch['i3d_rgb']
        i3d_flow = batch['i3d_flow']
        sam = batch['sam']
        vggish = batch['vggish']
        lm_labels = batch['lm_labels']
        input_mask = batch['input_mask']

        i3d_rgb_interval = batch['i3d_rgb_interval']
        i3d_flow_interval = batch['i3d_flow_interval']
        sam_interval = batch['sam_interval']
        audio_interval = batch['audio_interval']
        history_intervals = batch['history_intervals']
        question_intervals = batch['question_intervals']
        vis_state_vector_idx = batch['vis_state_vector_idx']
        history_state_vector_idx = batch['history_state_vector_idx']
        question_state_vector_idx = batch['question_state_vector_idx']

        ########################################################
        bart_output = self.model(
            input_ids=input_ids,
            video_place_holder_ids=video_place_holder_ids,
            i3d_rgb=i3d_rgb,
            i3d_flow=i3d_flow,
            sam=sam,
            vggish=vggish,
            attention_mask=input_mask,
            labels=lm_labels,
            i3d_rgb_interval=i3d_rgb_interval,
            i3d_flow_interval=i3d_flow_interval,
            sam_interval=sam_interval,
            audio_interval=audio_interval,
            history_intervals=history_intervals,
            question_intervals=question_intervals,
            vis_state_vector_idx=vis_state_vector_idx,
            history_state_vector_idx=history_state_vector_idx,
            question_state_vector_idx=question_state_vector_idx,
            output_attentions=True,
            return_dict=True
        )

        output = {}

        if self.config['print_output']:
            logits = bart_output['logits']
            probs = F.softmax(logits, dim=-1)
            preds = torch.topk(probs, 1)[1].squeeze(-1)
            preds = preds.tolist()
            lm_labels_list = lm_labels[:, 1:].tolist()
            lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
            responses = ''
            labels = ''
            for pred, label in zip(preds, lm_labels_list):
                responses += self.tokenizer.decode(pred) + '\n'
                labels += self.tokenizer.decode(label) + '\n'

            output['responses'] = responses
            output['gt'] = labels

        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
        gen_loss = bart_output['gen_loss']
        gen_loss = self.config['gen_coeff'] * gen_loss

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        if bart_output['elbo_loss_global'] is not None:
            elbo_global_loss = bart_output['elbo_loss_global']
            elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
        else:
            elbo_global_loss = torch.tensor(0.0)

        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        if bart_output['elbo_loss_local'] is not None:
            elbo_local_loss = bart_output['elbo_loss_local']
            elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
        else:
            elbo_local_loss = torch.tensor(0.0)

        total_loss = gen_loss + elbo_global_loss + elbo_local_loss

        output['losses'] = {
            gen_key: gen_loss,
            elbo_local_key: elbo_local_loss,
            elbo_global_key: elbo_global_loss,
            'tot_loss': total_loss
        }
        del bart_output
        return output

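    # The training objective assembled in forward() is a weighted sum,
    #
    #     tot_loss = gen_coeff * gen_loss
    #              + elbo_global_coeff * elbo_loss_global
    #              + elbo_local_coeff * elbo_loss_local,
    #
    # where either ELBO term falls back to 0 when the model does not return it.
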
    def generate(self, dataset, tag, tokenizer, gen_subset_num=None):

        self.model.eval()
        responses = {}
        i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = tokenizer.convert_tokens_to_ids(
            ['<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])

        # Generate the response for each round
        log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
        with torch.no_grad():
            for counter, dialog in enumerate(dataset):
                start_time = time()
                vid = dialog['vid']

                i3d_rgb = np.load(os.path.join(self.config['avsd_i3d_rgb_test'], vid + '.npy'))
                i3d_flow = np.load(os.path.join(self.config['avsd_i3d_flow_test'], vid + '.npy'))
                sam = np.load(os.path.join(self.config['avsd_objects_test'], vid + '.npy'))
                vggish = np.load(os.path.join(self.config['avsd_audio_test'], vid + '.npy'))

                min_length = min([self.config['vis_feat_length'], i3d_rgb.shape[0], i3d_flow.shape[0], sam.shape[0], vggish.shape[0]])
                sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb.shape[0] - 1, min_length)).astype(int)
                sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow.shape[0] - 1, min_length)).astype(int)
                sample_idx_sam = np.round(np.linspace(0, sam.shape[0] - 1, min_length)).astype(int)
                sample_idx_vggish = np.round(np.linspace(0, vggish.shape[0] - 1, min_length)).astype(int)

                i3d_rgb = torch.from_numpy(i3d_rgb[sample_idx_i3d_rgb, :]).float()
                i3d_flow = torch.from_numpy(i3d_flow[sample_idx_i3d_flow, :]).float()
                sam = torch.from_numpy(sam[sample_idx_sam, :]).float()
                vggish = torch.from_numpy(vggish[sample_idx_vggish, :]).float()

                dummy = torch.ones((1, min_length)) * ph_token
                video_place_holder_ids = torch.cat(
                    [torch.ones((1, 1)) * i3d_rgb_sep, dummy,
                     torch.ones((1, 1)) * i3d_flow_sep, dummy,
                     torch.ones((1, 1)) * sam_sep, dummy,
                     torch.ones((1, 1)) * audio_sep, dummy,
                     ], dim=-1).long()
                # Now we get the intervals of the visual input tokens
                # Here the intervals do not change across the batch dimension
                i3d_rgb_interval = [0, min_length + 1]  # the last token is not part of this modality
                i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
                sam_interval = [2 * min_length + 2, 3 * min_length + 3]
                audio_interval = [3 * min_length + 3, 4 * min_length + 4]
                vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]

                response = self.beam_search_generation(
                    dialog['caption'], dialog['history'],
                    i3d_rgb, i3d_flow, sam, vggish,
                    i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
                    vis_state_vector_idx, video_place_holder_ids, tokenizer)

                # Decode the response
                response = self.tokenizer.decode(response)
                responses[vid] = response
                # all_graphs[vid] = graphs
                time_elapsed = int(time() - start_time)
                print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))

        # Create a file with all responses
        with open(self.config['avsd_test_dstc{}'.format(self.config['dstc'])], 'r') as f:
            test_data = json.load(f)
        test_dialogs = deepcopy(test_data['dialogs'])
        # Filter the predicted dialogs
        test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

        for i, dialog in enumerate(test_dialogs):
            vid_id = dialog['image_id']
            gen_response = responses[vid_id]
            round_num_to_answer = len(dialog['dialog']) - 1
            assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
            dialog['dialog'][round_num_to_answer]['answer'] = gen_response
            test_dialogs[i] = dialog

        # Log the file
        file_name = 'results_dstc{}_beam_depth_{}'.format(self.config['dstc'], self.config['beam_depth'])
        if gen_subset_num is not None:
            file_name += f'-part_{gen_subset_num}'
        file_name = f'{tag}_' + file_name
        output_path = os.path.join(self.config['output_dir_dstc{}'.format(self.config['dstc'])], file_name + '.json')
        with open(output_path, 'w') as f:
            json.dump({'dialogs': test_dialogs}, f, indent=4)
        log.info('Results logged to {}'.format(output_path))
        print(os.getcwd())
        # Switch back to training mode
        self.model.train()

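    # beam_search_generation below keeps every hypothesis as a tuple
    # (generated_tokens, accumulated_log_prob, decoder_input_ids). Finished
    # hypotheses get an additive length reward,
    #
    #     score = log P(hyp) + length_penalty * (len(hyp) + 1),
    #
    # and the best-scoring one is returned. The encoder runs once; its outputs
    # are cached in encoder_outputs so later beam steps only re-run the decoder.
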
    def beam_search_generation(
            self, caption, history,
            i3d_rgb, i3d_flow, sam, vggish,
            i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
            vis_state_vector_idx, video_place_holder_ids, tokenizer):

        eos_token = tokenizer.eos_token_id
        unk_token = tokenizer.unk_token_id
        question_sep = tokenizer.convert_tokens_to_ids('<s5>')

        gen_ans = [eos_token]
        hyplist = [([], 0.0, [eos_token])]
        best_state = None
        comp_hyplist = []

        i3d_rgb = i3d_rgb.unsqueeze(0).cuda()
        i3d_flow = i3d_flow.unsqueeze(0).cuda()
        sam = sam.unsqueeze(0).cuda()
        vggish = vggish.unsqueeze(0).cuda()
        video_place_holder_ids = video_place_holder_ids.cuda()
        text_shift_len = video_place_holder_ids.size(-1)

        drop_caption = self.config['dstc'] == 10
        instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)

        input_ids = torch.tensor(instance['input_ids'])
        history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
        history_intervals = [[0 + text_shift_len, history_end.item() + text_shift_len]]  # The last token is the question state token (not part of the history)
        question_intervals = [[history_end.item() + text_shift_len, input_ids.size(0) + text_shift_len]]

        history_state_vector_idx = [x[0] + 1 for x in history_intervals]  # +1 because the history starts with <s><s4> .....
        question_state_vector_idx = [x[0] for x in question_intervals]  # the question interval already starts at its state token

        input_ids = input_ids.long().cuda().unsqueeze(0)
        encoder_outputs = None

        for i in range(self.config['max_generation_length']):
            new_hyplist = []
            argmin = 0
            for out, lp, st in hyplist:
                decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

                bart_output = self.model(
                    input_ids=input_ids,
                    video_place_holder_ids=video_place_holder_ids,
                    i3d_rgb=i3d_rgb,
                    i3d_flow=i3d_flow,
                    sam=sam,
                    vggish=vggish,
                    encoder_outputs=encoder_outputs,
                    decoder_input_ids=decoder_input_ids,
                    i3d_rgb_interval=i3d_rgb_interval,
                    i3d_flow_interval=i3d_flow_interval,
                    sam_interval=sam_interval,
                    audio_interval=audio_interval,
                    history_intervals=history_intervals,
                    question_intervals=question_intervals,
                    vis_state_vector_idx=vis_state_vector_idx,
                    history_state_vector_idx=history_state_vector_idx,
                    question_state_vector_idx=question_state_vector_idx,
                    output_attentions=True,
                    generate=True,
                    return_dict=True
                )

                if encoder_outputs is None:
                    encoder_outputs = [
                        bart_output['encoder_last_hidden_state'],
                        bart_output['encoder_hidden_states'],
                        bart_output['encoder_attentions'],
                        bart_output['encoder_QAs_local'],
                        bart_output['encoder_PAs_local'],
                        bart_output['encoder_QA_global'],
                        bart_output['encoder_PA_global'],
                        bart_output['encoder_state_vectors']
                    ]

                logits = bart_output['logits'][:, -1, :].squeeze()  # get the logits of the last token
                logp = F.log_softmax(logits, dim=0)
                lp_vec = logp.cpu().data.numpy() + lp
                if i >= self.config['min_generation_length']:
                    new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
                    comp_hyplist.append((out, new_lp))
                    if best_state is None or best_state < new_lp:
                        best_state = new_lp
                count = 1
                for o in np.argsort(lp_vec)[::-1]:  # descending order
                    if o in [eos_token, unk_token]:
                        continue
                    new_lp = lp_vec[o]
                    if len(new_hyplist) == self.config['beam_depth']:
                        if new_hyplist[argmin][1] < new_lp:
                            new_st = deepcopy(st)
                            new_st.append(int(o))
                            new_hyplist[argmin] = (out + [o], new_lp, new_st)
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                        else:
                            break
                    else:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist.append((out + [o], new_lp, new_st))
                        if len(new_hyplist) == self.config['beam_depth']:
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    count += 1
            hyplist = new_hyplist

        if len(comp_hyplist) > 0:
            maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
            return maxhyps[0][0]
        else:
            return []
300 runners/runner_nextqa.py Normal file
@@ -0,0 +1,300 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.nextqa_bart import AVSDBart
from time import time


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)

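# tokenize() recurses through nested containers, e.g. (hypothetical values):
#
#     tokenize('what is he doing', tokenizer)     # -> [id, id, ...]
#     tokenize({'q': 'who is there'}, tokenizer)  # -> {'q': [id, ...]}
#     tokenize(['hi', 'there'], tokenizer)        # -> [[...], [...]]
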
class NEXTQARunner(Runner):
    def __init__(self, config, tokenizer, vocab_size):
        super(NEXTQARunner, self).__init__(config)
        bart_config = BartConfig.from_json_file(self.config['bart_config'])

        self.model = AVSDBart.from_pretrained(
            'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)

        # Resize the embedding to match the vocab with additional special toks
        # This takes care of resizing weights of related parts of the network

        if vocab_size != bart_config.vocab_size:
            self.model.resize_token_embeddings(vocab_size)

        self.model.to(self.config['device'])
        if not self.config['generating']:
            self.optimizer, self.scheduler = init_optim(self.model, self.config)
        self.tokenizer = tokenizer

    def forward(self, batch):

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].cuda()

        ########################################################
        input_ids = batch['input_ids']
        video_place_holder_ids = batch['video_place_holder_ids']
        app_feats = batch['app_feats']
        mot_feats = batch['mot_feats']
        lm_labels = batch['lm_labels']
        input_mask = batch['input_mask']

        app_interval = batch['app_interval']
        mot_interval = batch['mot_interval']
        question_intervals = batch['question_intervals']
        vis_state_vector_idx = batch['vis_state_vector_idx']
        question_state_vector_idx = batch['question_state_vector_idx']
        ########################################################

        bart_output = self.model(
            input_ids=input_ids,
            video_place_holder_ids=video_place_holder_ids,
            i3d_rgb=app_feats,
            i3d_flow=mot_feats,
            attention_mask=input_mask,
            labels=lm_labels,
            i3d_rgb_interval=app_interval,
            i3d_flow_interval=mot_interval,
            question_intervals=question_intervals,
            vis_state_vector_idx=vis_state_vector_idx,
            question_state_vector_idx=question_state_vector_idx,
            output_attentions=True,
            return_dict=True
        )

        output = {}

        if self.config['print_output']:
            logits = bart_output['logits']
            probs = F.softmax(logits, dim=-1)
            preds = torch.topk(probs, 1)[1].squeeze(-1)
            preds = preds.tolist()
            lm_labels_list = lm_labels[:, 1:].tolist()
            lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
            responses = ''
            labels = ''
            for pred, label in zip(preds, lm_labels_list):
                responses += self.tokenizer.decode(pred) + '\n'
                labels += self.tokenizer.decode(label) + '\n'

            output['responses'] = responses
            output['gt'] = labels

        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
        gen_loss = bart_output['gen_loss']
        gen_loss = self.config['gen_coeff'] * gen_loss

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        if bart_output['elbo_loss_global'] is not None:
            elbo_global_loss = bart_output['elbo_loss_global']
            elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
        else:
            elbo_global_loss = torch.tensor(0.0)

        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        if bart_output['elbo_loss_local'] is not None:
            elbo_local_loss = bart_output['elbo_loss_local']
            elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
        else:
            elbo_local_loss = torch.tensor(0.0)

        total_loss = gen_loss + elbo_global_loss + elbo_local_loss

        output['losses'] = {
            gen_key: gen_loss,
            elbo_local_key: elbo_local_loss,
            elbo_global_key: elbo_global_loss,
            'tot_loss': total_loss
        }
        del bart_output
        return output

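    # NExT-QA has no flow/object/audio streams, so forward() routes the
    # appearance and motion features through the i3d_rgb / i3d_flow slots of
    # the shared AVSDBart interface; the loss assembly is identical to the
    # AVSD runner.
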
    def generate(self, dataset, app_feats, mot_feats, tag, tokenizer, start_idx_gen, end_idx_gen, gen_subset_num=None):

        self.model.eval()
        results = {}
        app_sep, mot_sep, ph_token = tokenizer.convert_tokens_to_ids(
            ['<s0>', '<s1>', '<place_holder>'])

        # Generate the response for each sample
        log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
        with torch.no_grad():
            counter = 0
            for idx in range(start_idx_gen, end_idx_gen):
                start_time = time()
                cur_sample = dataset.loc[idx]
                video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
                    str(cur_sample['answer']), str(cur_sample['qid'])
                if video_name not in results:
                    results[video_name] = {}

                input_ids = tokenize(ques, tokenizer)

                app_feat = app_feats[video_name]
                app_feat = torch.from_numpy(app_feat).type(torch.float32)

                mot_feat = mot_feats[video_name]
                mot_feat = torch.from_numpy(mot_feat).type(torch.float32)

                bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

                # Add state tokens
                input_ids.insert(0, ques_state)

                input_ids = torch.Tensor(input_ids).long()

                dummy = torch.ones((1, 16)) * ph_token
                video_place_holder_ids = torch.cat(
                    [torch.ones((1, 1)) * app_sep, dummy,
                     torch.ones((1, 1)) * mot_sep, dummy,
                     ], dim=-1).long()

                # Now we get the intervals of the visual input tokens
                # Here the intervals do not change across the batch dimension
                app_interval = [0, 16 + 1]  # the last token is not part of this modality
                mot_interval = [16 + 1, 2 * 16 + 2]
                vis_state_vector_idx = [app_interval[0], mot_interval[0]]

                response = self.beam_search_generation(
                    input_ids,
                    app_feat, mot_feat,
                    app_interval, mot_interval,
                    vis_state_vector_idx, video_place_holder_ids, tokenizer)

                # Decode the response
                response = self.tokenizer.decode(response)

                results[video_name][qid] = response
                time_elapsed = int(time() - start_time)
                print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
                counter += 1

        # Create a file with all responses
        file_name = 'results_nextqa_beam_depth_{}'.format(self.config['beam_depth'])
        if gen_subset_num is not None:
            file_name += f'-part_{gen_subset_num}'
        file_name = f'{tag}_' + file_name
        output_path = os.path.join(self.config['output_dir_nextqa'], file_name + '.json')
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=4)
        log.info('Results logged to {}'.format(output_path))
        print(os.getcwd())
        # Switch back to training mode
        self.model.train()

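    # With the fixed 16-token placeholder per modality built in generate(),
    # the encoder layout is: appearance tokens in [0, 17), motion tokens in
    # [17, 34), then the question starting at text_shift_len = 34; the
    # intervals passed to the model below follow that arithmetic.
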
    def beam_search_generation(
            self, input_ids,
            app_feat, mot_feat,
            app_interval, mot_interval,
            vis_state_vector_idx, video_place_holder_ids, tokenizer):

        eos_token = tokenizer.eos_token_id
        unk_token = tokenizer.unk_token_id
        question_sep = tokenizer.convert_tokens_to_ids('<s2>')

        gen_ans = [eos_token]
        hyplist = [([], 0.0, [eos_token])]
        best_state = None
        comp_hyplist = []

        app_feat = app_feat.unsqueeze(0).cuda()
        mot_feat = mot_feat.unsqueeze(0).cuda()
        video_place_holder_ids = video_place_holder_ids.cuda()
        text_shift_len = video_place_holder_ids.size(-1)

        question_intervals = [[0 + text_shift_len, input_ids.size(0) + text_shift_len]]

        question_state_vector_idx = [x[0] for x in question_intervals]

        input_ids = input_ids.long().cuda().unsqueeze(0)
        encoder_outputs = None

        for i in range(self.config['max_generation_length']):
            new_hyplist = []
            argmin = 0
            for out, lp, st in hyplist:
                decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

                bart_output = self.model(
                    input_ids=input_ids,
                    video_place_holder_ids=video_place_holder_ids,
                    i3d_rgb=app_feat,
                    i3d_flow=mot_feat,
                    encoder_outputs=encoder_outputs,
                    decoder_input_ids=decoder_input_ids,
                    i3d_rgb_interval=app_interval,
                    i3d_flow_interval=mot_interval,
                    question_intervals=question_intervals,
                    vis_state_vector_idx=vis_state_vector_idx,
                    question_state_vector_idx=question_state_vector_idx,
                    output_attentions=True,
                    generate=True,
                    return_dict=True
                )

                if encoder_outputs is None:
                    encoder_outputs = [
                        bart_output['encoder_last_hidden_state'],
                        bart_output['encoder_hidden_states'],
                        bart_output['encoder_attentions'],
                        bart_output['encoder_QAs_local'],
                        bart_output['encoder_PAs_local'],
                        bart_output['encoder_QA_global'],
                        bart_output['encoder_PA_global'],
                        bart_output['encoder_state_vectors']
                    ]

                logits = bart_output['logits'][:, -1, :].squeeze()  # get the logits of the last token
                logp = F.log_softmax(logits, dim=0)
                lp_vec = logp.cpu().data.numpy() + lp
                if i >= self.config['min_generation_length']:
                    new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
                    comp_hyplist.append((out, new_lp))
                    if best_state is None or best_state < new_lp:
                        best_state = new_lp
                count = 1
                for o in np.argsort(lp_vec)[::-1]:  # descending order
                    if o in [eos_token, unk_token]:
                        continue
                    new_lp = lp_vec[o]
                    if len(new_hyplist) == self.config['beam_depth']:
                        if new_hyplist[argmin][1] < new_lp:
                            new_st = deepcopy(st)
                            new_st.append(int(o))
                            new_hyplist[argmin] = (out + [o], new_lp, new_st)
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                        else:
                            break
                    else:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist.append((out + [o], new_lp, new_st))
                        if len(new_hyplist) == self.config['beam_depth']:
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    count += 1
            hyplist = new_hyplist

        if len(comp_hyplist) > 0:
            maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
            return maxhyps[0][0]
        else:
            return []