280 lines
10 KiB
Python
280 lines
10 KiB
Python
"""
|
|
author: Adnen Abdessaied
|
|
maintainer: "Adnen Abdessaied"
|
|
website: adnenabdessaied.de
|
|
version: 1.0.1
|
|
"""
|
|
|
|
from clevrDialog_dataset import ClevrDialogCaptionDataset
|
|
from models import SeqToSeqC, CaptionEncoder, Decoder
|
|
from optim import get_optim, adjust_lr
|
|
from options_caption_parser import Options
|
|
import os, json, torch, pickle, copy, time
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
import torch.utils.data as Data
|
|
from tensorboardX import SummaryWriter
|
|
|
|
|
|
class Execution:
    """Train and evaluate the caption-to-program seq2seq parser.

    The class owns the three dataset splits, the loss function, the
    TensorBoard writer and the checkpoint directory, and exposes ``run`` as
    the single entry point for both the ``train`` and the ``test`` mode.
    """

    def __init__(self, opts):
        """Load the datasets and prepare the logging / checkpoint folders.

        Args:
            opts: Parsed options. ``__init__`` reads ``dataPathTr``,
                ``dataPathVal``, ``dataPathTest``, ``vocabPath`` and
                ``run_dir``; the other methods read further fields.
        """
        self.opts = opts

        self.loss_fn = torch.nn.NLLLoss().cuda()

        print("[INFO] Loading dataset ...")

        self.dataset_tr = ClevrDialogCaptionDataset(
            opts.dataPathTr, opts.vocabPath, "train", "Captions Tr")

        self.dataset_val = ClevrDialogCaptionDataset(
            opts.dataPathVal, opts.vocabPath, "val", "Captions Val")

        self.dataset_test = ClevrDialogCaptionDataset(
            opts.dataPathTest, opts.vocabPath, "test", "Captions Test")

        tb_path = os.path.join(opts.run_dir, "tb_logdir")
        os.makedirs(tb_path, exist_ok=True)

        self.ckpt_path = os.path.join(opts.run_dir, "ckpt_dir")
        os.makedirs(self.ckpt_path, exist_ok=True)

        self.writer = SummaryWriter(tb_path)
        self.iter_val = 0
        self.bestValAcc = float("-inf")
        self.bestValIter = -1

    def constructNet(self, lenVocabText, lenVocabProg, maxLenProg, ):
        """Build the encoder/decoder pair and wrap it in ``SeqToSeqC``.

        Args:
            lenVocabText: Forwarded to ``CaptionEncoder``.
            lenVocabProg: Forwarded to ``Decoder``.
            maxLenProg: Forwarded to ``Decoder``.

        Returns:
            The assembled ``SeqToSeqC`` network.
        """
        # The decoder is built before the encoder on purpose: swapping the
        # order would change the RNG consumption and thus the initial weights.
        decoder = Decoder(self.opts, lenVocabProg, maxLenProg)
        encoder = CaptionEncoder(self.opts, lenVocabText)
        net = SeqToSeqC(encoder, decoder)
        return net

    @staticmethod
    def _unwrap(net):
        """Return the wrapped module if ``net`` is an ``nn.DataParallel``."""
        return net.module if isinstance(net, nn.DataParallel) else net

    @staticmethod
    def _load_weights(net, state_dict):
        """Load ``state_dict`` into ``net`` across DataParallel mismatches.

        A checkpoint written from an ``nn.DataParallel`` network prefixes every
        key with ``module.``; one written from a bare network does not. This
        helper picks the loading target so that both kinds of checkpoints can
        be loaded into both kinds of networks.
        """
        prefixed = bool(state_dict) and all(
            key.startswith("module.") for key in state_dict)
        wrapped = isinstance(net, nn.DataParallel)
        if prefixed and not wrapped:
            # A temporary wrapper shares its parameters with ``net``, so
            # loading through it fills ``net`` in place.
            nn.DataParallel(net).load_state_dict(state_dict)
        elif wrapped and not prefixed:
            net.module.load_state_dict(state_dict)
        else:
            net.load_state_dict(state_dict)

    @staticmethod
    def _format_program(prog):
        """Render one decoded program as a line of the results file."""
        # An empty program (first sampled token was <END>) must not crash.
        head = prog[0] if prog else ""
        return "( {} ( {} ) ) \n".format(head, ", ".join(map(str, prog[1:])))

    @staticmethod
    def _programs_match(targetProg, predProg):
        """Tell whether a predicted program counts as correct.

        A prediction is correct when it has the same length and the same
        first token as the target and every target argument occurs among the
        predicted arguments (argument order is ignored).
        """
        if len(targetProg) != len(predProg):
            return False
        if not targetProg:
            return True
        if targetProg[0] != predProg[0]:
            return False
        predArgs = predProg[1:]
        return all(arg in predArgs for arg in targetProg[1:])

    def _save_checkpoint(self, net, optim, _iter, epoch):
        """Write the current training state to ``ckpt_iter<_iter>.pkl``."""
        print("[INFO] Checkpointing model @ iter {}".format(_iter))
        state = {
            'state_dict': net.state_dict(),
            'optimizer': optim.optimizer.state_dict(),
            'lr_base': optim.lr_base,
            # Duplicate of 'lr_base'; kept so the checkpoint layout stays
            # unchanged for anything that already reads this key.
            'optim': optim.lr_base,
            'last_iter': _iter,
            'last_epoch': epoch,
        }
        torch.save(
            state,
            os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
        )

    def train(self, dataset, dataset_val=None):
        """Train the network for ``opts.num_iters`` iterations.

        Args:
            dataset: Training dataset; must expose ``lenVocabText``,
                ``lenVocabProg``, ``maxLenProg`` and ``name`` and yield
                ``(caption, captionPrg)`` pairs.
            dataset_val: Optional validation dataset evaluated every
                ``opts.validate_every`` iterations.

        Raises:
            ValueError: If the training dataloader yields no batches.
        """
        net = self.constructNet(
            dataset.lenVocabText, dataset.lenVocabProg, dataset.maxLenProg)

        net.cuda()
        net.train()

        # Define the multi-gpu training if needed
        if len(self.opts.gpu_ids) > 1:
            net = nn.DataParallel(net, device_ids=self.opts.gpu_ids)

        _iter = 0
        epoch = 0

        # Load checkpoint if resume training
        if self.opts.load_checkpoint_path is not None:
            print("[INFO] Resume trainig from ckpt {} ...".format(
                self.opts.load_checkpoint_path
            ))

            # Load the network parameters
            ckpt = torch.load(self.opts.load_checkpoint_path)
            print("[INFO] Checkpoint successfully loaded ...")
            self._load_weights(net, ckpt['state_dict'])

            # Load the optimizer parameters
            optim = get_optim(self.opts, net, len(dataset), lr_base=ckpt['lr_base'])
            optim.optimizer.load_state_dict(ckpt['optimizer'])

            # Continue counting where the checkpoint stopped so that logging
            # steps and checkpoint file names do not restart at zero.
            _iter = ckpt.get('last_iter', 0)
            epoch = ckpt.get('last_epoch', 0)
        else:
            optim = get_optim(self.opts, net, len(dataset))

        # Define dataloader
        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.opts.batch_size,
            shuffle=self.opts.shuffle_data,
            num_workers=self.opts.num_workers,
        )
        _totalCur = len(dataloader)
        if _totalCur == 0:
            # Without this guard the outer loop below would never terminate.
            raise ValueError("The training dataloader yields no batches")

        # Training loop
        while _iter < self.opts.num_iters:
            # Learning Rate Decay
            # NOTE(review): checked once per epoch, so a mark only fires when
            # an epoch starts exactly at that iteration - confirm intended.
            if _iter in self.opts.lr_decay_marks:
                adjust_lr(optim, self.opts.lr_decay_factor)

            time_start = time.time()
            # ``enumerate`` restarts the in-epoch counter for every epoch.
            for _iterCur, (caption, captionPrg) in enumerate(dataloader):
                if _iter >= self.opts.num_iters:
                    break
                caption = caption.cuda()
                captionPrg = captionPrg.cuda()
                captionPrgTarget = captionPrg.clone()
                optim.zero_grad()

                predSoftmax, _ = net(caption, captionPrg)

                # The prediction at position t is scored against the target
                # token at position t + 1.
                loss = self.loss_fn(
                    predSoftmax[:, :-1, :].contiguous().view(-1, predSoftmax.size(2)),
                    captionPrgTarget[:, 1:].contiguous().view(-1))
                loss.backward()
                lossValue = loss.item()

                # logging
                self.writer.add_scalar('train/loss', lossValue, global_step=_iter)
                self.writer.add_scalar('train/lr', optim._rate, global_step=_iter)
                if _iter % self.opts.display_every == 0:
                    print("\r[CLEVR-Dialog - %s (%d/%4d)][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" % (
                        dataset.name,
                        _iterCur,
                        _totalCur,
                        epoch,
                        _iter,
                        self.opts.num_iters,
                        lossValue,
                        optim._rate,
                    ), end=' ')
                optim.step()
                _iter += 1

                if _iter % self.opts.validate_every == 0:
                    if dataset_val is not None:
                        valAcc = self.eval(
                            net,
                            dataset_val,
                            valid=True,
                        )
                        # NOTE(review): the original indentation was lost;
                        # checkpointing is assumed to happen only on a new
                        # best validation accuracy - confirm.
                        if valAcc > self.bestValAcc:
                            self.bestValAcc = valAcc
                            self.bestValIter = _iter
                            self._save_checkpoint(net, optim, _iter, epoch)
                    else:
                        print("[INFO] No validation dataset available")

            time_end = time.time()
            print('Finished epoch in {}s'.format(int(time_end-time_start)))
            epoch += 1

        print("[INFO] Training done. Best model had val acc. {} @ iter {}...".format(self.bestValAcc, self.bestValIter))

    # Evaluation
    def eval(self, net, dataset, valid=False):
        """Compute the program accuracy of ``net`` on ``dataset``.

        Args:
            net: Network to evaluate; may be wrapped in ``nn.DataParallel``.
            dataset: Dataset yielding ``(caption, captionPrg)`` pairs and
                exposing ``vocab["idx_prog_to_token"]``.
            valid: When False, the predicted programs are additionally
                written to ``opts.res_path``.

        Returns:
            The accuracy in percent (0.0 for an empty dataset).
        """
        net.eval()
        # ``sample`` is defined on the model itself, not on the DataParallel
        # wrapper, so it has to be called on the unwrapped module.
        sampler = self._unwrap(net)
        dataloader = Data.DataLoader(
            dataset,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        numSteps = len(dataloader)
        idxToToken = dataset.vocab["idx_prog_to_token"]
        allPredictedProgs = []
        numAllProg = 0
        falsePred = 0
        # No gradients are needed while sampling.
        with torch.no_grad():
            for step, (caption, captionPrg) in enumerate(dataloader):
                print("\rEvaluation: [step %4d/%4d]" % (
                    step + 1,
                    numSteps,
                ), end=' ')
                caption = caption.cuda()
                captionPrg = captionPrg.cuda()
                tokens = sampler.sample(caption)
                targetProgs = decodeProg(captionPrg, idxToToken, target=True)
                predProgs = decodeProg(tokens, idxToToken)
                allPredictedProgs.extend(
                    self._format_program(prog) for prog in predProgs)
                numAllProg += len(targetProgs)
                falsePred += sum(
                    not self._programs_match(targetProg, predProg)
                    for targetProg, predProg in zip(targetProgs, predProgs))

        val_acc = (1 - (falsePred / numAllProg)) * 100.0 if numAllProg else 0.0
        print("Acc: {}".format(val_acc))
        net.train()
        if not valid:
            with open(self.opts.res_path, "w") as f:
                f.writelines(allPredictedProgs)
            print("[INFO] Predicted caption programs logged into {}".format(self.opts.res_path))
        return val_acc

    def run(self, run_mode):
        """Seed the RNGs and execute the requested mode.

        Args:
            run_mode: ``'train'`` or ``'test'``.

        Raises:
            SystemExit: With exit status -1 for any other mode.
        """
        self.set_seed(self.opts.seed)
        if run_mode == 'train':
            self.train(self.dataset_tr, self.dataset_val)

        elif run_mode == 'test':
            lenVocabText = self.dataset_test.lenVocabText
            lenVocabProg = self.dataset_test.lenVocabProg
            maxLenProg = self.dataset_test.maxLenProg
            net = self.constructNet(lenVocabText, lenVocabProg, maxLenProg)

            print('Loading ckpt {}'.format(self.opts.load_checkpoint_path))
            state_dict = torch.load(self.opts.load_checkpoint_path)['state_dict']
            self._load_weights(net, state_dict)
            net.cuda()
            self.eval(net, self.dataset_test)

        else:
            # Same exit status as before, but with a reason and without
            # relying on the ``exit`` helper injected by the ``site`` module.
            print("[ERROR] Unknown run mode '{}' (expected 'train' or 'test')".format(run_mode))
            raise SystemExit(-1)

    def set_seed(self, seed):
        """Sets the seed for reproducibility.

        Args:
            seed (int): The seed used
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # Trade cuDNN auto-tuning for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('[INFO] Seed set to {}...'.format(seed))
|
|
|
|
|
|
def decodeProg(tokens, prgIdxToToken, target=False, endIdx=2):
    """Convert a batch of program token indices into lists of tokens.

    Each sequence is read up to (and excluding) the first ``endIdx`` token and
    every index is looked up in ``prgIdxToToken``.

    Args:
        tokens: Batch of index sequences. Either an object offering
            ``tolist()`` (e.g. a tensor or array) or a plain nested list.
        prgIdxToToken (dict): Maps a token index to its token. Indices that
            are missing from the mapping decode to ``None``.
        target (bool): When True, the first decoded token of every sequence
            is dropped (presumably the start marker of a ground-truth
            program - verify against the dataset).
        endIdx (int): Index of the <END> token; defaults to 2.

    Returns:
        list[list]: One list of decoded tokens per input sequence.
    """
    tokensBatch = tokens.tolist() if hasattr(tokens, "tolist") else tokens
    progsBatch = []
    for sequence in tokensBatch:
        prog = []
        for tok in sequence:
            if tok == endIdx:
                break
            prog.append(prgIdxToToken.get(tok))
        if target:
            prog = prog[1:]
        progsBatch.append(prog)
    return progsBatch
|
|
|
|
|
|
def main():
    """Parse the command-line options and execute the selected mode."""
    options = Options().parse()
    Execution(options).run(options.mode)
    print("[INFO] Done ...")


if __name__ == "__main__":
    main()
|