# vlcn/code/core/exec.py
# 2022-03-30 10:46:35 +02:00
#
# 523 lines
# 20 KiB
# Python

# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
from core.data.dataset import VideoQA_Dataset
from core.model.net import Net1, Net2, Net3, Net4
from core.model.optim import get_optim, adjust_lr
from core.metrics import get_acc
from tqdm import tqdm
from core.data.utils import shuffle_list
import os, json, torch, datetime, pickle, copy, shutil, time, math
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
from tensorboardX import SummaryWriter
from torch.autograd import Variable as var
class Execution:
    """Drive training, validation and testing of a VideoQA network.

    The constructor loads the training split, the validation split (only
    when ``EVAL_EVERY_EPOCH`` is set) and the test split, and opens a
    TensorBoard writer at ``TB_PATH``.  ``run`` is the entry point that
    dispatches on the requested run mode.
    """

    def __init__(self, __C):
        """Load the datasets and create the TensorBoard writer.

        Args:
            __C: configuration object; it is deep-copied for every split so
                that ``RUN_MODE`` can be overridden per dataset.
        """
        self.__C = __C

        print('Loading training set ........')
        self.dataset = self._load_split('train')

        # The validation split is only needed for per-epoch evaluation.
        self.dataset_eval = None
        if self.__C.EVAL_EVERY_EPOCH:
            print('Loading validation set for per-epoch evaluation ........')
            self.dataset_eval = self._load_split('val')
            self._share_vocab(self.dataset_eval)

        self.dataset_test = self._load_split('test')
        self._share_vocab(self.dataset_test)

        self.writer = SummaryWriter(self.__C.TB_PATH)

    def _load_split(self, run_mode):
        """Build a dataset from a config copy whose RUN_MODE is `run_mode`."""
        split_cfg = copy.deepcopy(self.__C)
        setattr(split_cfg, 'RUN_MODE', run_mode)
        return VideoQA_Dataset(split_cfg)

    def _share_vocab(self, target):
        """Copy the training split's answer/token tables onto `target`."""
        source = self.dataset
        # NOTE(review): ans_list is overwritten with the training split's
        # list as well -- confirm this is intended for val/test datasets.
        target.ans_list = source.ans_list
        target.ans_to_ix, target.ix_to_ans = source.ans_to_ix, source.ix_to_ans
        target.token_to_ix, target.pretrained_emb = \
            source.token_to_ix, source.pretrained_emb

    def train(self, dataset, dataset_eval=None):
        """Train the network, checkpointing and logging after every epoch.

        Args:
            dataset: training dataset; each batch is an 8-tuple whose first
                four entries are question indices, frame features, clip
                features and the training target.
            dataset_eval: optional dataset evaluated after every epoch.

        Side effects: writes ``epoch<N>.pkl`` checkpoints under
        ``CKPTS_PATH/ckpt_<VERSION>``, appends to the run log under
        ``LOG_PATH`` and, when ``VERBOSE`` is set, logs loss and learning
        rate to TensorBoard and stdout.
        """
        data_size = dataset.data_size
        ckpt_dir = self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION
        log_path = self.__C.LOG_PATH + 'log_run_' + self.__C.VERSION + '.txt'

        net = self.construct_net(self.__C.MODEL_TYPE)
        # construct_net only builds model types 1-4, so this branch is not
        # reachable through it at the moment; kept for compatibility.
        if os.path.isfile(self.__C.PRETRAINED_PATH) and self.__C.MODEL_TYPE == 11:
            print('Loading pretrained DNC-weigths')
            net.load_pretrained_weights()
        net.cuda()
        net.train()

        # Define the multi-gpu training if needed
        if self.__C.N_GPU > 1:
            net = nn.DataParallel(net, device_ids=self.__C.DEVICES)

        # Binary cross entropy summed over the (sub-)batch.
        loss_fn = torch.nn.BCELoss(reduction='sum').cuda()

        if self.__C.RESUME:
            print(' ========== Resume training')
            if self.__C.CKPT_PATH is not None:
                print('Warning: you are now using CKPT_PATH args, '
                      'CKPT_VERSION and CKPT_EPOCH will not work')
                path = self.__C.CKPT_PATH
            else:
                path = self.__C.CKPTS_PATH + \
                    'ckpt_' + self.__C.CKPT_VERSION + \
                    '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl'

            # Load the network parameters.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            print('Loading ckpt {}'.format(path))
            ckpt = torch.load(path)
            print('Finish!')
            net.load_state_dict(ckpt['state_dict'])

            # Load the optimizer parameters.  Older checkpoints stored
            # lr_base under 'optim' by mistake; fall back to the configured
            # optimizer for those.
            optim_name = ckpt['optim']
            if optim_name == ckpt['lr_base']:
                optim_name = self.__C.OPTIM
            optim = get_optim(self.__C, net, data_size, optim_name,
                              lr_base=ckpt['lr_base'])
            optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH)
            optim.optimizer.load_state_dict(ckpt['optimizer'])
            start_epoch = self.__C.CKPT_EPOCH

            # The resumed run may save under a version that has no
            # checkpoint directory yet.
            os.makedirs(ckpt_dir, exist_ok=True)
        else:
            # Start from a clean checkpoint directory for this version.
            if ('ckpt_' + self.__C.VERSION) in os.listdir(self.__C.CKPTS_PATH):
                shutil.rmtree(ckpt_dir)
            os.mkdir(ckpt_dir)
            optim = get_optim(self.__C, net, data_size, self.__C.OPTIM)
            start_epoch = 0

        loss_sum = 0

        # With 'external' shuffling the sample list is shuffled by hand at
        # the start of every epoch, so the loader itself must not shuffle.
        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.__C.BATCH_SIZE,
            shuffle=self.__C.SHUFFLE_MODE not in ['external'],
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=self.__C.PIN_MEM,
            drop_last=True
        )

        sub_batch_size = self.__C.SUB_BATCH_SIZE
        steps_per_epoch = math.ceil(data_size / self.__C.BATCH_SIZE)
        mode_str = None
        if self.__C.VERBOSE:
            target_split = 'val' if dataset_eval is not None else 'test'
            mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT[target_split]

        # Training script
        for epoch in range(start_epoch, self.__C.MAX_EPOCH):
            # Save log information
            with open(log_path, 'a+') as logfile:
                logfile.write(
                    'nowTime: ' +
                    datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
                    '\n'
                )

            # Learning Rate Decay
            if epoch in self.__C.LR_DECAY_LIST:
                adjust_lr(optim, self.__C.LR_DECAY_R)

            # Externally shuffle
            if self.__C.SHUFFLE_MODE == 'external':
                shuffle_list(dataset.ans_list)

            time_start = time.time()
            for step, (
                    ques_ix_iter,
                    frames_feat_iter,
                    clips_feat_iter,
                    ans_iter,
                    _,
                    _,
                    _,
                    _
            ) in enumerate(dataloader):
                ques_ix_iter = ques_ix_iter.cuda()
                frames_feat_iter = frames_feat_iter.cuda()
                clips_feat_iter = clips_feat_iter.cuda()
                ans_iter = ans_iter.cuda()

                optim.zero_grad()

                # Gradient accumulation: the batch is processed in
                # GRAD_ACCU_STEPS consecutive slices of SUB_BATCH_SIZE.
                for accu_step in range(self.__C.GRAD_ACCU_STEPS):
                    begin = accu_step * sub_batch_size
                    sub = slice(begin, begin + sub_batch_size)

                    pred = net(
                        frames_feat_iter[sub],
                        clips_feat_iter[sub],
                        ques_ix_iter[sub]
                    )
                    # The loss is sum-reduced, so it is not divided by
                    # GRAD_ACCU_STEPS (only a mean-reduced loss would be).
                    loss = loss_fn(pred, ans_iter[sub])
                    loss.backward()

                    # Accumulate the summed loss so that the epoch log can
                    # report the average loss per sample.
                    loss_value = loss.item()
                    loss_sum += loss_value

                    if self.__C.VERBOSE:
                        sample_loss = loss_value / sub_batch_size
                        global_step = step + epoch * steps_per_epoch
                        self.writer.add_scalar(
                            'train/loss', sample_loss, global_step=global_step)
                        self.writer.add_scalar(
                            'train/lr', optim._rate, global_step=global_step)
                        print("\r[exp_name %s][version %s][epoch %2d][step %4d/%4d][%s] loss: %.4f, lr: %.2e" % (
                            self.__C.EXP_NAME,
                            self.__C.VERSION,
                            epoch + 1,
                            step,
                            int(data_size / self.__C.BATCH_SIZE),
                            mode_str,
                            sample_loss,
                            optim._rate,
                        ), end=' ')

                # Gradient norm clipping
                if self.__C.GRAD_NORM_CLIP > 0:
                    nn.utils.clip_grad_norm_(
                        net.parameters(),
                        self.__C.GRAD_NORM_CLIP
                    )

                optim.step()

            time_end = time.time()
            print('Finished in {}s'.format(int(time_end - time_start)))

            epoch_finish = epoch + 1

            # Save checkpoint.  'optim' holds the optimizer selector that
            # get_optim receives on a fresh run, so resuming rebuilds the
            # same optimizer.
            state = {
                'state_dict': net.state_dict(),
                'optimizer': optim.optimizer.state_dict(),
                'lr_base': optim.lr_base,
                'optim': self.__C.OPTIM,
            }
            torch.save(
                state,
                ckpt_dir + '/epoch' + str(epoch_finish) + '.pkl'
            )

            # Logging
            with open(log_path, 'a+') as logfile:
                logfile.write(
                    'epoch = ' + str(epoch_finish) +
                    ' loss = ' + str(loss_sum / data_size) +
                    '\n' +
                    'lr = ' + str(optim._rate) +
                    '\n\n'
                )

            # Eval after every epoch
            if dataset_eval is not None:
                self.eval(
                    net,
                    dataset_eval,
                    self.writer,
                    epoch,
                    valid=True,
                )

            loss_sum = 0

    @staticmethod
    def _bucket(bucket_names, sample_buckets, answers, predictions):
        """Group answers and predictions by the bucket of each sample.

        Args:
            bucket_names: every bucket that must appear in the result.
            sample_buckets: bucket name of each sample, aligned with
                `answers` and `predictions`.

        Returns:
            ``(answers_by_bucket, predictions_by_bucket)`` dicts of lists.

        Raises:
            KeyError: if a sample's bucket is not in `bucket_names`.
        """
        ans_by_bucket = {name: [] for name in bucket_names}
        pred_by_bucket = {name: [] for name in bucket_names}
        for bucket, ans, pred in zip(sample_buckets, answers, predictions):
            pred_by_bucket[bucket].append(pred)
            ans_by_bucket[bucket].append(ans)
        return ans_by_bucket, pred_by_bucket

    # Evaluation
    def eval(self, net, dataset, writer, epoch, valid=False):
        """Run `net` over `dataset` and print accuracy figures.

        Args:
            net: network to evaluate; switched to eval mode for the pass
                and back to train mode before returning.
            dataset: dataset whose batches are 8-tuples; the answer index,
                question type, question-length bin and answer rarity are
                read from the last four entries.
            writer: TensorBoard writer, only used when `valid` is true.
            epoch: global step for the TensorBoard scalars (only used when
                `valid` is true).
            valid: when true, report overall and per-question-type accuracy
                and log them to `writer`; otherwise additionally report
                accuracy per question-length bin and per answer rarity,
                printing only.
        """
        q_type_names = ('what', 'who', 'how', 'when', 'where')
        qlen_bin_names = ('1-3', '4-8', '9-15')
        ans_rarity_names = ('0-99', '100-299', '300-999')

        ans_ix_list = []
        pred_list = []
        q_type_list = []
        q_bin_list = []
        ans_rarity_list = []

        data_size = dataset.data_size

        net.eval()
        # During training the net is already wrapped; only wrap a bare net.
        if self.__C.N_GPU > 1 and not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net, device_ids=self.__C.DEVICES)

        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.__C.EVAL_BATCH_SIZE,
            shuffle=False,
            num_workers=self.__C.NUM_WORKERS,
            pin_memory=True
        )

        for step, (
                ques_ix_iter,
                frames_feat_iter,
                clips_feat_iter,
                _,
                ans_iter,
                q_type,
                qlen_bin,
                ans_rarity
        ) in enumerate(dataloader):
            print("\rEvaluation: [step %4d/%4d]" % (
                step,
                int(data_size / self.__C.EVAL_BATCH_SIZE),
            ), end=' ')

            ques_ix_iter = ques_ix_iter.cuda()
            frames_feat_iter = frames_feat_iter.cuda()
            clips_feat_iter = clips_feat_iter.cuda()

            with torch.no_grad():
                pred = net(
                    frames_feat_iter,
                    clips_feat_iter,
                    ques_ix_iter
                )

            # The predicted answer is the highest-scoring class.
            pred_np = pred.cpu().data.numpy()
            pred_argmax = np.argmax(pred_np, axis=1)
            pred_list.extend(pred_argmax)
            ans_ix_list.extend(ans_iter.tolist())
            q_type_list.extend(q_type.tolist())
            q_bin_list.extend(qlen_bin.tolist())
            ans_rarity_list.extend(ans_rarity.tolist())

        print('')
        assert len(pred_list) == len(ans_ix_list) == len(q_type_list) == len(q_bin_list) == len(ans_rarity_list)

        # Map class indices back to answers before scoring.
        pred_list = [dataset.ix_to_ans[pred] for pred in pred_list]
        ans_ix_list = [dataset.ix_to_ans[ans] for ans in ans_ix_list]

        ans_qtype_dict, pred_qtype_dict = self._bucket(
            q_type_names,
            [dataset.idx_to_qtypes[q_type] for q_type in q_type_list],
            ans_ix_list,
            pred_list,
        )

        if valid:
            print('----------------- Computing scores -----------------')
            acc = get_acc(ans_ix_list, pred_list)
            print('----------------- Overall -----------------')
            print('acc: {}'.format(acc))
            writer.add_scalar('acc/overall', acc, global_step=epoch)

            for q_type in q_type_names:
                print('----------------- Computing "{}" q-type scores -----------------'.format(q_type))
                acc = get_acc(ans_qtype_dict[q_type], pred_qtype_dict[q_type])
                print('acc: {}'.format(acc))
                writer.add_scalar(
                    'acc/{}'.format(q_type), acc, global_step=epoch)
        else:
            ans_qlen_bin_dict, pred_qlen_bin_dict = self._bucket(
                qlen_bin_names,
                [dataset.idx_to_qlen_bins[qlen_bin] for qlen_bin in q_bin_list],
                ans_ix_list,
                pred_list,
            )
            ans_ans_rarity_dict, pred_ans_rarity_dict = self._bucket(
                ans_rarity_names,
                [dataset.idx_to_ans_rare[a_rarity] for a_rarity in ans_rarity_list],
                ans_ix_list,
                pred_list,
            )

            print('----------------- Computing overall scores -----------------')
            acc = get_acc(ans_ix_list, pred_list)
            print('----------------- Overall -----------------')
            print('acc:{}'.format(acc))

            print('----------------- Computing q-type scores -----------------')
            for q_type in q_type_names:
                acc = get_acc(ans_qtype_dict[q_type], pred_qtype_dict[q_type])
                print(' {} '.format(q_type))
                print('acc:{}'.format(acc))

            print('----------------- Computing qlen-bins scores -----------------')
            for qlen_bin in qlen_bin_names:
                acc = get_acc(ans_qlen_bin_dict[qlen_bin], pred_qlen_bin_dict[qlen_bin])
                print(' {} '.format(qlen_bin))
                print('acc:{}'.format(acc))

            print('----------------- Computing ans-rarity scores -----------------')
            for a_rarity in ans_rarity_names:
                acc = get_acc(ans_ans_rarity_dict[a_rarity], pred_ans_rarity_dict[a_rarity])
                print(' {} '.format(a_rarity))
                print('acc:{}'.format(acc))

        net.train()

    def construct_net(self, model_type):
        """Instantiate the network selected by `model_type` (1 to 4).

        Raises:
            ValueError: if `model_type` is not a supported network id.
        """
        net_classes = {1: Net1, 2: Net2, 3: Net3, 4: Net4}
        try:
            net_cls = net_classes[model_type]
        except (KeyError, TypeError):
            raise ValueError('Net{} is not supported'.format(model_type)) from None
        return net_cls(
            self.__C,
            self.dataset.pretrained_emb,
            self.dataset.token_size,
            self.dataset.ans_size
        )

    @staticmethod
    def _strip_parallel_prefix(state_dict):
        """Drop the ``module.`` key prefix added by ``nn.DataParallel``.

        The mapping is returned unchanged unless every key carries the
        prefix, so checkpoints saved from a bare network are not touched.
        """
        prefix = 'module.'
        if not state_dict or not all(key.startswith(prefix) for key in state_dict):
            return state_dict
        stripped = type(state_dict)()
        for key, value in state_dict.items():
            stripped[key[len(prefix):]] = value
        return stripped

    def _load_trained_net(self, epoch):
        """Build the network and load ``epoch<epoch>.pkl`` of this version.

        Raises:
            ValueError: if `epoch` is None.
        """
        if epoch is None:
            raise ValueError('epoch must be given to load a checkpoint')

        net = self.construct_net(self.__C.MODEL_TYPE)
        path = self.__C.CKPTS_PATH + \
            'ckpt_' + self.__C.VERSION + \
            '/epoch' + str(epoch) + '.pkl'
        print('Loading ckpt {}'.format(path))
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path)['state_dict']
        # Checkpoints written during multi-GPU training come from the
        # DataParallel wrapper, while `net` here is the bare network.
        net.load_state_dict(self._strip_parallel_prefix(state_dict))
        net.cuda()
        return net

    def run(self, run_mode, epoch=None):
        """Seed the RNGs and execute the requested run mode.

        Args:
            run_mode: ``'train'``, ``'val'`` or ``'test'``.
            epoch: checkpoint epoch to load; required for ``'val'`` and
                ``'test'``.

        Raises:
            ValueError: if `epoch` is missing for ``'val'`` / ``'test'``.
            SystemExit: with code -1 for an unknown `run_mode`.
        """
        self.set_seed(self.__C.SEED)

        if run_mode == 'train':
            self.empty_log(self.__C.VERSION)
            self.train(self.dataset, self.dataset_eval)
        elif run_mode == 'val':
            net = self._load_trained_net(epoch)
            # Prefer the validation split; it is only loaded when
            # EVAL_EVERY_EPOCH is set, otherwise fall back to self.dataset.
            dataset = self.dataset_eval if self.dataset_eval is not None else self.dataset
            # Training-time evaluation logs checkpoint N at step N - 1.
            self.eval(net, dataset, self.writer, epoch - 1, valid=True)
        elif run_mode == 'test':
            net = self._load_trained_net(epoch)
            self.eval(net, self.dataset_test, self.writer, 0)
        else:
            print('Unknown run mode: {}'.format(run_mode))
            raise SystemExit(-1)

    def set_seed(self, seed):
        """Sets the seed for reproducibility.

        Seeds torch (CPU and CUDA) and numpy, and switches cuDNN to its
        deterministic, non-benchmarking configuration.

        Args:
            seed (int): The seed used
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('\nSeed set to {}...\n'.format(seed))

    def empty_log(self, version):
        """Delete the run log of `version` under LOG_PATH if it exists."""
        print('Initializing log file ........')
        log_file = self.__C.LOG_PATH + 'log_run_' + version + '.txt'
        try:
            os.remove(log_file)
        except FileNotFoundError:
            pass  # nothing to clear
        print('Finished!')
        print('')