Make code public

Adnen Abdessaied 2024-07-08 11:41:28 +02:00
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions

runners/runner.py

@@ -0,0 +1,488 @@
import wandb
import os
import os.path as osp
import json
from collections import deque, OrderedDict
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils import clip_grad_value_
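# Base runner: generic training, validation and checkpointing logic.
# Concrete runners are expected to implement forward() and to attach
# self.model, self.optimizer and self.scheduler before train() is called.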
class Runner:
def __init__(self, config):
self.config = config
if 'rank' in config:
self.gpu_rank = config['rank']
else:
self.gpu_rank = 0
self.epoch_idx = 0
self.min_gen_val_loss = float('inf')
self.best_epoch_idx = 0
if self.config["max_ckpt_to_keep"] > 0:
self.checkpoint_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.setup_wandb()
def setup_wandb(self):
if self.gpu_rank == 0:
print("[INFO] Set wandb logging on rank {}".format(0))
run = wandb.init(
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
else:
run = None
self.run = run
def forward(self, batch, eval=False):
raise NotImplementedError
def train(self, dataset, dataset_eval):
batch_size = self.config['batch_size']
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
else:
sampler = None
data_loader = tud.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=self.config['training'] and not self.config['parallel'],
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler
)
start_epoch_idx = self.epoch_idx
num_iter_epoch = self.config['num_iter_per_epoch']
if self.config['display']:
log.info(f'{num_iter_epoch} iter per epoch.')
num_epochs = self.config['num_epochs']
# Perform validation before training
if self.config['eval_first']:
_ = self.val(dataset_eval)
for epoch_idx in range(start_epoch_idx, num_epochs):
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler.set_epoch(epoch_idx)
self.epoch_idx = epoch_idx
if self.config['display']:
log.info(f'starting epoch {epoch_idx}')
log.info('training')
self.model.train()
num_batch = 0
next_logging_pct = .1
start_time = time.time()
self.optimizer.zero_grad()
for batch in data_loader:
num_batch += 1
pct = num_batch / num_iter_epoch * 100
iter_now = num_iter_epoch * epoch_idx + num_batch
output = self.forward(batch)
losses = output['losses']
# optimizer step
losses['tot_loss'] /= self.config['batch_multiply']
# debug
if self.config['debugging']:
log.info('try backward')
losses['tot_loss'].backward()
if self.config['clip_grad_value'] > 0:
clip_grad_value_(self.model.parameters(), self.config['clip_grad_value'])
if self.config['debugging']:
log.info('backward done')
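# Gradient accumulation: the total loss above was divided by batch_multiply,
# and the optimizer / LR scheduler only step once every batch_multiply batches.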
if iter_now % self.config['batch_multiply'] == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# display and eval
if pct >= next_logging_pct:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
if self.config['print_output']:
print(10 * '-' + 'responses' + 10 * '-')
print(output['responses'])
print(10 * '-' + 'gt' + 10 * '-')
print(output['gt'])
next_logging_pct += self.config["next_logging_pct"]
if self.config['debugging']:
break
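# End-of-epoch logging: scheduler.get_lr() returns one value per parameter group;
# the first group is assumed to hold the BART backbone and the last one the
# remaining parameters (see init_optim in optim_utils).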
lr_bart, lr_other = self.scheduler.get_lr()[0], self.scheduler.get_lr()[-1]
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
if self.run:
self.run.log(
{
f"Train/{gen_key}": losses[gen_key].item(),
f"Train/{elbo_global_key}": losses[elbo_global_key].item(),
f"Train/{elbo_local_key}": losses[elbo_local_key].item(),
"Train/total_loss": losses['tot_loss'].item(),
},
step=iter_now
)
self.run.log(
{"Train/lr_bart": lr_bart, "Train/lr_other": lr_other},
step=iter_now
)
del losses
del output
if self.config['display']:
log.info(
f'100%,\ttime:\t{time.time() - start_time:.2f}'
)
if not self.config['overfit'] and self.run:
self.save_ckpt()
if not self.config['skip_eval']:
iter_now = num_iter_epoch * (epoch_idx + 1)
val_losses = self.val(dataset_eval)
if self.config['display']:
log.info('#'*100)
for k in val_losses:
log.info('Average val {} (epoch {}) = {}'.format(k, self.epoch_idx, val_losses[k]))
log.info('#'*100)
gen_val_loss = val_losses[gen_key]
if gen_val_loss < self.min_gen_val_loss:
self.min_gen_val_loss = gen_val_loss
self.best_epoch_idx = epoch_idx
# Log the best model w.r.t. the validation data
if self.run and self.config['save_ckpt']:
self.save_ckpt_best()
if self.run:
self.run.log(
{
f"Val/{gen_key}": val_losses[gen_key],
f"Val/{elbo_global_key}": val_losses[elbo_global_key],
f"Val/{elbo_local_key}": val_losses[elbo_local_key],
"Val/total_loss": val_losses['tot_loss'],
"Val/min_gen_loss": self.min_gen_val_loss
},
step=iter_now
)
if self.config['parallel']:
if self.config['dp_type'] == 'dp':
gc.collect()
torch.cuda.empty_cache()
else:
dist.barrier()
torch.cuda.empty_cache()
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
if self.config['display']:
log.info('Stop for reaching stop_epochs.')
break
if self.config['display']:
log.info(f'Best validation loss was reached at epoch {self.best_epoch_idx}.')
def val(self, dataset):
total_loss_val = 0.0
total_gen_loss_val = 0.0
total_elbo_global_val = 0.0
total_elbo_local_val = 0.0
num_batch_val = 0
next_logging_pct_val = 0.05
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
# Prepare the dataloader
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_val = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
sampler_val.set_epoch(self.epoch_idx)
else:
sampler_val = None
data_loader_val = tud.DataLoader(
dataset=dataset,
batch_size=self.config['batch_size'],
shuffle=False,
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_val
)
if self.config['parallel'] and self.config['dp_type'] == 'dp':
num_iter_per_epoch_val = int(np.ceil(len(dataset) / self.config['batch_size']))
else:
num_iter_per_epoch_val = int(np.ceil(len(dataset) / (self.config['batch_size'] * self.config['num_gpus'])))
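# Under DDP each rank only iterates over its own shard of the validation set,
# hence the division by num_gpus; the per-rank loss averages computed below are
# not re-synchronised across processes.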
self.model.eval()
if self.gpu_rank == 0:
start_time = time.time()
for batch in data_loader_val:
num_batch_val += 1
pct = num_batch_val / num_iter_per_epoch_val * 100
with torch.no_grad():
output = self.forward(batch)
losses = output['losses']
losses['tot_loss'] /= self.config['batch_multiply']
losses[elbo_global_key] /= self.config['batch_multiply']
losses[elbo_local_key] /= self.config['batch_multiply']
losses[gen_key] /= self.config['batch_multiply']
total_loss_val += losses['tot_loss'].item()
total_gen_loss_val += losses[gen_key].item()
total_elbo_global_val += losses[elbo_global_key].item()
total_elbo_local_val += losses[elbo_local_key].item()
# display and eval
if pct >= next_logging_pct_val:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Validating][Iter : {num_batch_val}/{num_iter_per_epoch_val}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
next_logging_pct_val += self.config["next_logging_pct"]
loss_val = total_loss_val / num_batch_val
gen_loss_val = total_gen_loss_val / num_batch_val
elbo_global_val = total_elbo_global_val / num_batch_val
elbo_local_val = total_elbo_local_val / num_batch_val
losses_val = {
'tot_loss': loss_val,
elbo_global_key: elbo_global_val,
elbo_local_key: elbo_local_val,
gen_key: gen_loss_val
}
self.model.train()
return losses_val
def save_eval_results(self, split, epoch_idx, metrics_results):
metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
with open(metrics_filename, 'w') as f:
json.dump(metrics_results, f)
log.info(f'Results of metrics saved to {metrics_filename}')
if self.config["max_ckpt_to_keep"] > 0:
if len(self.metrics_queue) == self.metrics_queue.maxlen:
todel = self.metrics_queue.popleft()
os.remove(todel)
self.metrics_queue.append(metrics_filename)
if epoch_idx == 'best':
self.copy_best_predictions(split)
def copy_best_results(self, split, epoch_idx):
to_print = 'Copy '
if not self.config['skip_saving_ckpt']:
ckpt_path = osp.join(self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
shutil.copyfile(ckpt_path, best_ckpt_path)
to_print += best_ckpt_path + ' '
metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
shutil.copyfile(metrics_filename, best_metric_filename)
to_print += best_metric_filename + ' '
log.info(to_print)
def set_ckpt(self, ckpt_dict):
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_state_dict = model.state_dict()
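# Partial loading: only parameters whose names also exist in the current model
# are copied from the checkpoint; all remaining parameters keep their values.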
former_dict = {k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
if self.config['display']:
log.info("number of keys transferred: %d" % len(former_dict))
assert len(former_dict.keys()) > 0
model_state_dict.update(former_dict)
model.load_state_dict(model_state_dict)
if self.config['display']:
log.info('loaded model')
del model_state_dict, former_dict
# if not self.config['uses_new_optimizer']:
if not self.config['generating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
if not self.config['restarts']:
self.epoch_idx = ckpt_dict['epoch_idx'] + 1
if not self.config['resets_min_val_loss']:
self.min_gen_val_loss = ckpt_dict['min_gen_val_loss']
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.config['display']:
log.info('loaded optimizer')
if 'scheduler' in ckpt_dict:
self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * self.config['num_iter_per_epoch']
self.scheduler.load_state_dict(ckpt_dict['scheduler'])
del ckpt_dict
torch.cuda.empty_cache()
def save_ckpt(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
if self.config['skip_saving_ckpt']:
return ckpt_path
torch_major, torch_minor = (int(v) for v in torch.__version__.split('.')[:2])
if (torch_major, torch_minor) > (1, 4):
# Keep checkpoints loadable with torch <= 1.4 by using the legacy serialization format
torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
else:
torch.save(ckpt, f=ckpt_path)
del ckpt
if not self.config['parallel']:
torch.cuda.empty_cache()
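# Rotate on-disk checkpoints: once max_ckpt_to_keep files exist, the oldest
# one is removed before the new checkpoint path is enqueued.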
if self.config["max_ckpt_to_keep"] > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
os.remove(todel)
self.checkpoint_queue.append(ckpt_path)
def save_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
torch.save(ckpt, f=ckpt_path)
del ckpt
return ckpt_path
def load_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
if not osp.exists(ckpt_path):
ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
sort_func = lambda x:int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def get_ckpt(self):
ckpt = {
'epoch_idx': self.epoch_idx,
'min_gen_val_loss': self.min_gen_val_loss,
'seed': self.config['random_seed'],
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
ckpt['model_state_dict'] = model_to_save.state_dict()
return ckpt
def load_ckpt(self, ckpt_path=None):
if not ckpt_path:
if self.config['generating']: # or self.config['start_ckpt_for_generating']:
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
else:
ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
sort_func = lambda x:int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
if re.search(r"(\d+)", epoch_name):
self.checkpoint_queue.append(ckpt_path)
metrics_filename = osp.join(self.config['log_dir'], f'metrics_{epoch_name}.json')
if osp.exists(metrics_filename):
self.metrics_queue.append(metrics_filename)
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def match_model_key(self, pretrained_dict, model_dict):
matched_dict = dict()
not_found = []
for key in pretrained_dict:
if key in model_dict:
matched_key = key
elif key.startswith('encoder.') and key[8:] in model_dict:
matched_key = key[8:]
elif key.startswith('module.') and key[7:] in model_dict:
matched_key = key[7:]
elif 'encoder.' + key in model_dict:
matched_key = 'encoder.' + key
elif 'module.' + key in model_dict:
matched_key = 'module.' + key
else:
not_found.append(key)
continue
matched_dict[matched_key] = pretrained_dict[key]
print("Keys from pretrained_dict that were not found in model_dict:\n", not_found)
return matched_dict
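# --- Usage sketch (annotation, not part of the original file) ----------------
# A concrete runner such as AVSDRunner below implements forward() and attaches
# self.model / self.optimizer / self.scheduler; training then reduces to:
#
#   runner = AVSDRunner(config, tokenizer, vocab_size)
#   runner.load_ckpt()                       # resume from the newest epoch_*.ckpt, if any
#   runner.train(train_dataset, val_dataset)
#
# where config is the parsed experiment configuration (a plain dict) and the
# datasets provide the collate_fn expected by the DataLoader in train()/val().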

runners/runner_avsd.py

@@ -0,0 +1,337 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.avsd_bart import AVSDBart
from custom_datasets.avsd import build_input_from_segments
from time import time
class AVSDRunner(Runner):
def __init__(self, config, tokenizer, vocab_size):
super(AVSDRunner, self).__init__(config)
bart_config = BartConfig.from_json_file(self.config['bart_config'])
self.model = AVSDBart.from_pretrained(
'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)
# Resize the embeddings to match the vocabulary extended with the additional special tokens.
# This also takes care of resizing the weights of the tied parts of the network (e.g. the LM head).
# pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
# print(pytorch_total_params)
if vocab_size != bart_config.vocab_size:
self.model.resize_token_embeddings(vocab_size)
self.model.to(self.config['device'])
if not self.config['generating']:
self.optimizer, self.scheduler = init_optim(self.model, self.config)
self.tokenizer = tokenizer
def forward(self, batch):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].cuda()
########################################################
input_ids = batch['input_ids']
video_place_holder_ids = batch['video_place_holder_ids']
i3d_rgb = batch['i3d_rgb']
i3d_flow = batch['i3d_flow']
sam = batch['sam']
vggish = batch['vggish']
lm_labels = batch['lm_labels']
input_mask = batch['input_mask']
i3d_rgb_interval = batch['i3d_rgb_interval']
i3d_flow_interval = batch['i3d_flow_interval']
sam_interval = batch['sam_interval']
audio_interval = batch['audio_interval']
history_intervals = batch['history_intervals']
question_intervals = batch['question_intervals']
vis_state_vector_idx = batch['vis_state_vector_idx']
history_state_vector_idx = batch['history_state_vector_idx']
question_state_vector_idx = batch['question_state_vector_idx']
########################################################
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=i3d_rgb,
i3d_flow=i3d_flow,
sam=sam,
vggish=vggish,
attention_mask=input_mask,
labels=lm_labels,
i3d_rgb_interval=i3d_rgb_interval,
i3d_flow_interval=i3d_flow_interval,
sam_interval=sam_interval,
audio_interval=audio_interval,
history_intervals=history_intervals,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
history_state_vector_idx=history_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
return_dict=True
)
output = {}
if self.config['print_output']:
logits = bart_output['logits']
probs = F.softmax(logits, dim=-1)
preds = torch.topk(probs, 1)[1].squeeze(-1)
preds = preds.tolist()
lm_labels_list = lm_labels[:, 1:].tolist()
lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
responses = ''
labels = ''
for pred, label in zip(preds, lm_labels_list):
responses += self.tokenizer.decode(pred) + '\n'
labels += self.tokenizer.decode(label) + '\n'
output['responses'] = responses
output['gt'] = labels
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
gen_loss = bart_output['gen_loss']
gen_loss = self.config['gen_coeff'] * gen_loss
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
if bart_output['elbo_loss_global'] is not None:
elbo_global_loss = bart_output['elbo_loss_global']
elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
else:
elbo_global_loss = torch.tensor(0.0)
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
if bart_output['elbo_loss_local'] is not None:
elbo_local_loss = bart_output['elbo_loss_local']
elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
else:
elbo_local_loss = torch.tensor(0.0)
total_loss = gen_loss + elbo_global_loss + elbo_local_loss
output['losses'] = {
gen_key: gen_loss,
elbo_local_key: elbo_local_loss,
elbo_global_key: elbo_global_loss,
'tot_loss': total_loss
}
del bart_output
return output
def generate(self, dataset, tag, tokenizer, gen_subset_num=None):
self.model.eval()
responses = {}
i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = tokenizer.convert_tokens_to_ids(
['<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# Generate the response for the last round of each test dialog
log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
with torch.no_grad():
for counter, dialog in enumerate(dataset):
start_time = time()
vid = dialog['vid']
i3d_rgb = np.load(os.path.join(self.config['avsd_i3d_rgb_test'], vid + '.npy'))
i3d_flow = np.load(os.path.join(self.config['avsd_i3d_flow_test'], vid + '.npy'))
sam = np.load(os.path.join(self.config['avsd_objects_test'], vid + '.npy'))
vggish = np.load(os.path.join(self.config['avsd_audio_test'], vid + '.npy'))
min_length = min([self.config['vis_feat_length'], i3d_rgb.shape[0], i3d_flow.shape[0], sam.shape[0], vggish.shape[0]])
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb.shape[0] - 1, min_length)).astype(int)
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow.shape[0] - 1, min_length)).astype(int)
sample_idx_sam = np.round(np.linspace(0, sam.shape[0] - 1, min_length)).astype(int)
sample_idx_vggish = np.round(np.linspace(0, vggish.shape[0] - 1, min_length)).astype(int)
i3d_rgb = torch.from_numpy(i3d_rgb[sample_idx_i3d_rgb, :]).float()
i3d_flow = torch.from_numpy(i3d_flow[sample_idx_i3d_flow, :]).float()
sam = torch.from_numpy(sam[sample_idx_sam, :]).float()
vggish = torch.from_numpy(vggish[sample_idx_vggish, :]).float()
dummy = torch.ones((1, min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((1, 1)) * i3d_rgb_sep, dummy,
torch.ones((1, 1)) * i3d_flow_sep, dummy,
torch.ones((1, 1)) * sam_sep, dummy,
torch.ones((1, 1)) * audio_sep, dummy,
], dim=-1).long()
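# Each modality contributes one state token (<s0>..<s3>) followed by min_length
# placeholder tokens, so every interval below spans min_length + 1 positions and
# starts at that modality's state token.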
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
i3d_rgb_interval = [0, min_length + 1] # the last token is not part of this modality
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
response = self.beam_search_generation(
dialog['caption'], dialog['history'],
i3d_rgb, i3d_flow, sam, vggish,
i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer)
# Decode the response
response = self.tokenizer.decode(response)
responses[vid] = response
# all_graphs[vid] = graphs
time_elapsed = int(time() - start_time)
print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
# Create a file with all responses
with open(self.config['avsd_test_dstc{}'.format(self.config['dstc'])], 'r') as f:
test_data = json.load(f)
test_dialogs = deepcopy(test_data['dialogs'])
# Filter the predicted dialogs
test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))
for i, dialog in enumerate(test_dialogs):
vid_id = dialog['image_id']
gen_response = responses[vid_id]
round_num_to_answer = len(dialog['dialog'])-1
assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
dialog['dialog'][round_num_to_answer]['answer'] = gen_response
test_dialogs[i] = dialog
# Log the file
file_name = 'results_dstc{}_beam_depth_{}'.format(self.config['dstc'], self.config['beam_depth'])
if gen_subset_num is not None:
file_name += f'-part_{gen_subset_num}'
file_name = f'{tag}_' + file_name
output_path = os.path.join(self.config['output_dir_dstc{}'.format(self.config['dstc'])], file_name + '.json')
with open(output_path, 'w') as f:
json.dump({'dialogs': test_dialogs}, f, indent=4)
log.info('Results logged to {}'.format(output_path))
print(os.getcwd())
# Switch back to training mode
self.model.train()
def beam_search_generation(
self, caption, history,
i3d_rgb, i3d_flow, sam, vggish,
i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer):
eos_token = tokenizer.eos_token_id
unk_token = tokenizer.unk_token_id
question_sep = tokenizer.convert_tokens_to_ids('<s5>')
gen_ans = [eos_token]
hyplist = [([], 0.0, [eos_token])]
best_state = None
comp_hyplist = []
i3d_rgb = i3d_rgb.unsqueeze(0).cuda()
i3d_flow = i3d_flow.unsqueeze(0).cuda()
sam = sam.unsqueeze(0).cuda()
vggish = vggish.unsqueeze(0).cuda()
video_place_holder_ids = video_place_holder_ids.cuda()
text_shift_len = video_place_holder_ids.size(-1)
drop_caption = self.config['dstc'] == 10
instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)
input_ids = torch.tensor(instance['input_ids'])
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_intervals = [[0 + text_shift_len, history_end.item() + text_shift_len]] # The last token is the question state token (not part of the history)
question_intervals = [[history_end.item() + text_shift_len, input_ids.size(0) + text_shift_len]]
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> .....
question_state_vector_idx = [x[0] for x in question_intervals] # no offset: the question interval already starts at the question state token <s5>
input_ids = input_ids.long().cuda().unsqueeze(0)
encoder_outputs = None
for i in range(self.config['max_generation_length']):
new_hyplist = []
argmin = 0
for out, lp, st in hyplist:
decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=i3d_rgb,
i3d_flow=i3d_flow,
sam=sam,
vggish=vggish,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
i3d_rgb_interval=i3d_rgb_interval,
i3d_flow_interval=i3d_flow_interval,
sam_interval=sam_interval,
audio_interval=audio_interval,
history_intervals=history_intervals,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
history_state_vector_idx=history_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
generate=True,
return_dict=True
)
if encoder_outputs is None:
encoder_outputs = [
bart_output['encoder_last_hidden_state'],
bart_output['encoder_hidden_states'],
bart_output['encoder_attentions'],
bart_output['encoder_QAs_local'],
bart_output['encoder_PAs_local'],
bart_output['encoder_QA_global'],
bart_output['encoder_PA_global'],
bart_output['encoder_state_vectors']
]
logits = bart_output['logits'][:,-1,:].squeeze() # get the logits of the last token
logp = F.log_softmax(logits, dim=0)
lp_vec = logp.cpu().data.numpy() + lp
if i >= self.config['min_generation_length']:
new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
comp_hyplist.append((out, new_lp))
if best_state is None or best_state < new_lp:
best_state = new_lp
count = 1
for o in np.argsort(lp_vec)[::-1]: # reverse the order
if o in [eos_token, unk_token]:
continue
new_lp = lp_vec[o]
if len(new_hyplist) == self.config['beam_depth']:
if new_hyplist[argmin][1] < new_lp:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist[argmin] = (out + [o], new_lp, new_st)
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
else:
break
else:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist.append((out + [o], new_lp, new_st))
if len(new_hyplist) == self.config['beam_depth']:
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
count += 1
hyplist = new_hyplist
if len(comp_hyplist) > 0:
maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
return maxhyps[0][0]
else:
return []
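For reference, the hypothesis bookkeeping of beam_search_generation above can be illustrated with a toy, self-contained sketch (token ids and probabilities are made up; this snippet is an annotation, not part of the file):

# Toy illustration of one beam-update step (beam_depth hypotheses kept per step).
import numpy as np

beam_depth = 2
eos_token, unk_token = 2, 3
hyplist = [([], 0.0, [eos_token])]                        # (generated ids, log-prob, decoder input ids)
logp = np.log(np.array([0.10, 0.60, 0.05, 0.05, 0.20]))  # fake next-token log-probabilities

new_hyplist, argmin = [], 0
for out, lp, st in hyplist:
    lp_vec = logp + lp
    for o in np.argsort(lp_vec)[::-1]:                    # best continuations first
        if o in (eos_token, unk_token):
            continue
        new_lp = lp_vec[o]
        if len(new_hyplist) == beam_depth:
            if new_hyplist[argmin][1] < new_lp:           # replace the current worst hypothesis
                new_hyplist[argmin] = (out + [int(o)], new_lp, st + [int(o)])
                argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
            else:
                break                                     # remaining candidates can only be worse
        else:
            new_hyplist.append((out + [int(o)], new_lp, st + [int(o)]))
            if len(new_hyplist) == beam_depth:
                argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]

print([(ids, round(float(score), 3)) for ids, score, _ in new_hyplist])
# -> [([1], -0.511), ([4], -1.609)]: the two best non-EOS/UNK continuations survive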

runners/runner_nextqa.py

@@ -0,0 +1,300 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.nextqa_bart import AVSDBart
from time import time
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class NEXTQARunner(Runner):
def __init__(self, config, tokenizer, vocab_size):
super(NEXTQARunner, self).__init__(config)
bart_config = BartConfig.from_json_file(self.config['bart_config'])
self.model = AVSDBart.from_pretrained(
'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)
# Resize the embeddings to match the vocabulary extended with the additional special tokens.
# This also takes care of resizing the weights of the tied parts of the network (e.g. the LM head).
if vocab_size != bart_config.vocab_size:
self.model.resize_token_embeddings(vocab_size)
self.model.to(self.config['device'])
if not self.config['generating']:
self.optimizer, self.scheduler = init_optim(self.model, self.config)
self.tokenizer = tokenizer
def forward(self, batch):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].cuda()
########################################################
input_ids = batch['input_ids']
video_place_holder_ids = batch['video_place_holder_ids']
app_feats = batch['app_feats']
mot_feats = batch['mot_feats']
lm_labels = batch['lm_labels']
input_mask = batch['input_mask']
app_interval = batch['app_interval']
mot_interval = batch['mot_interval']
question_intervals = batch['question_intervals']
vis_state_vector_idx = batch['vis_state_vector_idx']
question_state_vector_idx = batch['question_state_vector_idx']
########################################################
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=app_feats,
i3d_flow=mot_feats,
attention_mask=input_mask,
labels=lm_labels,
i3d_rgb_interval=app_interval,
i3d_flow_interval=mot_interval,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
return_dict=True
)
output = {}
if self.config['print_output']:
logits = bart_output['logits']
probs = F.softmax(logits, dim=-1)
preds = torch.topk(probs, 1)[1].squeeze(-1)
preds = preds.tolist()
lm_labels_list = lm_labels[:, 1:].tolist()
lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
responses = ''
labels = ''
for pred, label in zip(preds, lm_labels_list):
responses += self.tokenizer.decode(pred) + '\n'
labels += self.tokenizer.decode(label) + '\n'
output['responses'] = responses
output['gt'] = labels
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
gen_loss = bart_output['gen_loss']
gen_loss = self.config['gen_coeff'] * gen_loss
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
if bart_output['elbo_loss_global'] is not None:
elbo_global_loss = bart_output['elbo_loss_global']
elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
else:
elbo_global_loss = torch.tensor(0.0)
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
if bart_output['elbo_loss_local'] is not None:
elbo_local_loss = bart_output['elbo_loss_local']
elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
else:
elbo_local_loss = torch.tensor(0.0)
total_loss = gen_loss + elbo_global_loss + elbo_local_loss
output['losses'] = {
gen_key: gen_loss,
elbo_local_key: elbo_local_loss,
elbo_global_key: elbo_global_loss,
'tot_loss': total_loss
}
del bart_output
return output
def generate(self, dataset, app_feats, mot_feats, tag, tokenizer, start_idx_gen, end_idx_gen, gen_subset_num=None):
self.model.eval()
results = {}
app_sep, mot_sep, ph_token = tokenizer.convert_tokens_to_ids(
['<s0>', '<s1>', '<place_holder>'])
# Generate the answer for each question
log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
with torch.no_grad():
counter = 0
for idx in range(start_idx_gen, end_idx_gen):
start_time = time()
cur_sample = dataset.loc[idx]
video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
if video_name not in results:
results[video_name] = {}
input_ids = tokenize(ques, tokenizer)
app_feat = app_feats[video_name]
app_feat = torch.from_numpy(app_feat).type(torch.float32)
mot_feat = mot_feats[video_name]
mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])
# Add state tokens
input_ids.insert(0, ques_state)
input_ids = torch.Tensor(input_ids).long()
dummy = torch.ones((1, 16)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((1, 1)) * app_sep, dummy,
torch.ones((1, 1)) * mot_sep, dummy,
], dim=-1).long()
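# The appearance/motion features are assumed to come as 16 clip-level vectors
# per video, matching the 16 placeholder tokens inserted for each modality here.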
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
app_interval = [0, 16 + 1] # the last token is not part of this modality
mot_interval = [16 + 1, 2 * 16 + 2]
vis_state_vector_idx = [app_interval[0], mot_interval[0]]
response = self.beam_search_generation(
input_ids,
app_feat, mot_feat,
app_interval, mot_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer)
# Decode the response
response = self.tokenizer.decode(response)
results[video_name][qid] = response
time_elapsed = int(time() - start_time)
print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
counter += 1
# Create a file with all responses
file_name = 'results_nextqa_beam_depth_{}'.format(self.config['beam_depth'])
if gen_subset_num is not None:
file_name += f'-part_{gen_subset_num}'
file_name = f'{tag}_' + file_name
output_path = os.path.join(self.config['output_dir_nextqa'], file_name + '.json')
with open(output_path, 'w') as f:
json.dump(results, f, indent=4)
log.info('Results logged to {}'.format(output_path))
print(os.getcwd())
# Switch back to training mode
self.model.train()
def beam_search_generation(
self, input_ids,
app_feat, mot_feat,
app_interval, mot_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer):
eos_token = tokenizer.eos_token_id
unk_token = tokenizer.unk_token_id
question_sep = tokenizer.convert_tokens_to_ids('<s2>')
gen_ans = [eos_token]
hyplist = [([], 0.0, [eos_token])]
best_state = None
comp_hyplist = []
app_feat = app_feat.unsqueeze(0).cuda()
mot_feat = mot_feat.unsqueeze(0).cuda()
video_place_holder_ids = video_place_holder_ids.cuda()
text_shift_len = video_place_holder_ids.size(-1)
question_intervals = [[0 + text_shift_len, input_ids.size(0) + text_shift_len]] # the question spans the whole text part of the input
question_state_vector_idx = [x[0] for x in question_intervals]
input_ids = input_ids.long().cuda().unsqueeze(0)
encoder_outputs = None
for i in range(self.config['max_generation_length']):
new_hyplist = []
argmin = 0
for out, lp, st in hyplist:
decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=app_feat,
i3d_flow=mot_feat,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
i3d_rgb_interval=app_interval,
i3d_flow_interval=mot_interval,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
generate=True,
return_dict=True
)
if encoder_outputs is None:
encoder_outputs = [
bart_output['encoder_last_hidden_state'],
bart_output['encoder_hidden_states'],
bart_output['encoder_attentions'],
bart_output['encoder_QAs_local'],
bart_output['encoder_PAs_local'],
bart_output['encoder_QA_global'],
bart_output['encoder_PA_global'],
bart_output['encoder_state_vectors']
]
logits = bart_output['logits'][:,-1,:].squeeze() # get the logits of the last token
logp = F.log_softmax(logits, dim=0)
lp_vec = logp.cpu().data.numpy() + lp
if i >= self.config['min_generation_length']:
new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
comp_hyplist.append((out, new_lp))
if best_state is None or best_state < new_lp:
best_state = new_lp
count = 1
for o in np.argsort(lp_vec)[::-1]: # reverse the order
if o in [eos_token, unk_token]:
continue
new_lp = lp_vec[o]
if len(new_hyplist) == self.config['beam_depth']:
if new_hyplist[argmin][1] < new_lp:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist[argmin] = (out + [o], new_lp, new_st)
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
else:
break
else:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist.append((out + [o], new_lp, new_st))
if len(new_hyplist) == self.config['beam_depth']:
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
count += 1
hyplist = new_hyplist
if len(comp_hyplist) > 0:
maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
return maxhyps[0][0]
else:
return []