neuro-symbolic-visual-dialog/prog_generator/train_question_parser.py

913 lines
38 KiB
Python

"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
import os
import sys
import json, torch, pickle, copy, time
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
from tensorboardX import SummaryWriter
from copy import deepcopy
from clevrDialog_dataset import ClevrDialogQuestionDataset
import pickle
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from executor.symbolic_executor import SymbolicExecutorClevr, SymbolicExecutorMinecraft
from models import SeqToSeqQ, QuestEncoder_1, QuestEncoder_2, Decoder, CaptionEncoder, SeqToSeqC
from optim import get_optim, adjust_lr
from options_caption_parser import Options as OptionsC
from options_question_parser import Options as OptionsQ
class Execution:
    """Training / evaluation harness for the question program parser.

    The object owns the three dataset splits, the question network
    (``QuestionNet``), an optional pre-trained caption network
    (``CaptionNet``), a symbolic executor used to turn predicted programs into
    answers, and the tensorboard writer / output directories.
    """

    def __init__(self, optsQ, optsC):
        """Build datasets, networks, executor and output directories.

        Args:
            optsQ: parsed options of the question parser (deep-copied into
                ``self.opts``).
            optsC: parsed options of the caption parser; only used to build
                ``CaptionNet`` when ``optsQ.captionNetPath`` exists.
        """
        self.opts = deepcopy(optsQ)
        if self.opts.useCuda > 0 and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            print("[INFO] Using GPU {} ...".format(torch.cuda.get_device_name(0)))
        else:
            print("[INFO] Using CPU ...")
            self.device = torch.device("cpu")

        self.loss_fn = torch.nn.NLLLoss().to(self.device)

        print("[INFO] Loading dataset ...")
        self.datasetTr = ClevrDialogQuestionDataset(
            self.opts.dataPathTr, self.opts.vocabPath, "train", "All tr data")
        self.datasetVal = ClevrDialogQuestionDataset(
            self.opts.dataPathVal, self.opts.vocabPath, "val", "All val data", train=False)
        self.datasetTest = ClevrDialogQuestionDataset(
            self.opts.dataPathTest, self.opts.vocabPath, "test", "All val data", train=False)

        self.QuestionNet = constructQuestionNet(
            self.opts,
            self.datasetTr.lenVocabText,
            self.datasetTr.lenVocabProg,
            self.datasetTr.maxLenProg,
        )

        # Always define the attribute so that a missing caption checkpoint is
        # reported explicitly by the evaluation methods.
        self.CaptionNet = None
        if os.path.isfile(self.opts.captionNetPath):
            self.CaptionNet = constructCaptionNet(
                optsC,
                self.datasetTr.lenVocabText,
                self.datasetTr.lenVocabProg,
                self.datasetTr.maxLenProg
            )
            print('Loading CaptionNet from {}'.format(self.opts.captionNetPath))
            # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
            state_dict = torch.load(
                self.opts.captionNetPath, map_location=self.device)['state_dict']
            self.CaptionNet.load_state_dict(state_dict)
            self.CaptionNet.to(self.device)
            total_params_cap = sum(
                p.numel() for p in self.CaptionNet.parameters() if p.requires_grad)
            print("The caption encoder has {} trainable parameters".format(total_params_cap))

        self.QuestionNet.to(self.device)
        total_params_quest = sum(
            p.numel() for p in self.QuestionNet.parameters() if p.requires_grad)
        print("The question encoder has {} trainable parameters".format(total_params_quest))

        if "minecraft" in self.opts.scenesPath:
            self.symbolicExecutor = SymbolicExecutorMinecraft(self.opts.scenesPath)
        else:
            self.symbolicExecutor = SymbolicExecutorClevr(self.opts.scenesPath)

        tb_path = os.path.join(self.opts.run_dir, "tb_logdir")
        os.makedirs(tb_path, exist_ok=True)
        self.ckpt_path = os.path.join(self.opts.run_dir, "ckpt_dir")
        os.makedirs(self.ckpt_path, exist_ok=True)
        os.makedirs(self.opts.text_log_dir, exist_ok=True)

        self.writer = SummaryWriter(tb_path)
        self.iter_val = 0

        self.dependencies = None
        if os.path.isfile(self.opts.dependenciesPath):
            # NOTE: pickle executes arbitrary code on load; only point
            # dependenciesPath at trusted files.
            with open(self.opts.dependenciesPath, "rb") as f:
                self.dependencies = pickle.load(f)

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _wrap_data_parallel(self):
        """Wrap QuestionNet in nn.DataParallel once when several GPUs are configured."""
        if len(self.opts.gpu_ids) > 1 and not isinstance(self.QuestionNet, nn.DataParallel):
            self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)

    def _question_sampler(self):
        """Return the module exposing ``sample``.

        nn.DataParallel only forwards ``forward``; custom methods live on the
        wrapped ``.module``.
        """
        net = self.QuestionNet
        return net.module if isinstance(net, nn.DataParallel) else net

    def _require_caption_net(self):
        """Fail early with a clear message when no caption network was loaded."""
        if self.CaptionNet is None:
            raise RuntimeError(
                "CaptionNet is not available: no checkpoint found at {}".format(
                    self.opts.captionNetPath))

    def _select_last_rounds(self, hist, rounds):
        """Restrict every sample's history to a window of ``last_n_rounds`` rounds.

        Args:
            hist: history tensor; 3-D histories are sliced along dim 1, 2-D
                histories are sliced in blocks of 20 positions per round
                (presumably the per-round token length -- TODO confirm).
            rounds: list with one round number per sample.

        Returns:
            The stacked per-sample windows (on CPU for the 2-D case).
        """
        n_rounds = self.opts.last_n_rounds
        windows = []
        for i, r in enumerate(rounds):
            start = max(r - n_rounds, 0)
            end = max(r, n_rounds)
            if hist.dim() == 3:
                assert end - start == n_rounds
                windows.append(hist[i, start:end, :])
            elif hist.dim() == 2:
                tokens = hist[i, start * 20:end * 20].cpu()
                # 1 / 2 / 0 look like <START> / <END> / padding indices -- TODO confirm.
                if r > n_rounds:
                    windows.append(
                        torch.cat([torch.tensor([1]), tokens, torch.tensor([2])], 0))
                else:
                    windows.append(torch.cat([tokens, torch.tensor([2, 0])], 0))
        return torch.stack(windows, dim=0)

    # ------------------------------------------------------------------ #
    # Training / validation
    # ------------------------------------------------------------------ #
    def train(self):
        """Train QuestionNet for ``opts.num_iters`` iterations.

        Validates every ``opts.validate_every`` iterations and writes a
        checkpoint whenever the validation accuracy improves.
        """
        self.QuestionNet.train()
        self._wrap_data_parallel()

        if os.path.isfile(self.opts.load_checkpoint_path):
            print("[INFO] Resume trainig from ckpt {} ...".format(
                self.opts.load_checkpoint_path
            ))
            ckpt = torch.load(self.opts.load_checkpoint_path, map_location=self.device)
            print("[INFO] Checkpoint successfully loaded ...")
            self.QuestionNet.load_state_dict(ckpt['state_dict'])
            optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))
            optim.optimizer.load_state_dict(ckpt['optimizer'])
        else:
            optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))

        # The counters restart from zero even when resuming from a checkpoint.
        _iter = 0
        epoch = 0
        trainTime = 0
        bestValAcc = float("-inf")
        bestCkp = 0

        while _iter < self.opts.num_iters:
            # NOTE(review): the decay marks are only tested once per epoch, so a
            # mark takes effect only if an epoch starts exactly on it -- confirm.
            if _iter in self.opts.lr_decay_marks:
                adjust_lr(optim, self.opts.lr_decay_factor)

            dataloader = Data.DataLoader(
                self.datasetTr,
                batch_size=self.opts.batch_size,
                shuffle=self.opts.shuffle_data,
                num_workers=self.opts.num_workers,
            )

            time_start = 0
            time_end = 0
            for batch_iter, (quest, hist, prog, questionRound, _) in enumerate(dataloader):
                time_start = time.time()
                if _iter >= self.opts.num_iters:
                    break

                quest = quest.to(self.device)
                if self.opts.last_n_rounds < 10:
                    hist = self._select_last_rounds(hist, questionRound.tolist())
                hist = hist.to(self.device)
                prog = prog.to(self.device)
                progTarget = prog.clone()

                optim.zero_grad()
                # Teacher forcing: feed prog[:, :-1], predict prog[:, 1:].
                predSoftmax, _ = self.QuestionNet(quest, hist, prog[:, :-1])
                loss = self.loss_fn(
                    predSoftmax.contiguous().view(-1, predSoftmax.size(2)),
                    progTarget[:, 1:].contiguous().view(-1))
                loss.backward()
                loss_value = loss.item()

                if _iter % self.opts.validate_every == 0 and _iter > 0:
                    valAcc = self.val()
                    if valAcc > bestValAcc:
                        bestValAcc = valAcc
                        bestCkp = _iter
                        print("\n[INFO] Checkpointing model @ iter {} with val accuracy {}\n".format(_iter, valAcc))
                        state = {
                            'state_dict': self.QuestionNet.state_dict(),
                            'optimizer': optim.optimizer.state_dict(),
                            'lr_base': optim.lr_base,
                            'optim': optim.lr_base,
                            'last_iter': _iter,
                            'last_epoch': epoch,
                        }
                        torch.save(
                            state,
                            os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
                        )

                self.writer.add_scalar('train/loss', loss_value, global_step=_iter)
                self.writer.add_scalar('train/lr', optim._rate, global_step=_iter)

                if _iter % self.opts.display_every == 0:
                    # NOTE(review): only the duration of displayed iterations is
                    # accumulated, so the "Avg. epoch time" printed below is not
                    # a real per-epoch wall time.
                    time_end = time.time()
                    trainTime += time_end - time_start
                    print("\r[CLEVR-Dialog - %s (%d | %d)][epoch %2d][iter %4d/%4d][runtime %4f] loss: %.4f, lr: %.2e" % (
                        self.datasetTr.name,
                        batch_iter,
                        len(dataloader),
                        epoch,
                        _iter,
                        self.opts.num_iters,
                        trainTime,
                        loss_value,
                        optim._rate,
                    ), end=' ')

                optim.step()
                _iter += 1

            epoch += 1

        # max(..., 1) avoids a ZeroDivisionError when num_iters <= 0.
        print("[INFO] Avg. epoch time: {} s".format(trainTime / max(epoch, 1)))
        print("[INFO] Best model achieved val acc. {} @ iter {}".format(bestValAcc, bestCkp))

    def val(self):
        """Return the exact-match program accuracy (in percent) on the validation split.

        The network is switched to eval mode for the pass and back to train
        mode before returning. Returns 0.0 when the split is empty.
        """
        self._wrap_data_parallel()
        self.QuestionNet = self.QuestionNet.eval()
        question_net = self._question_sampler()

        dataloader = Data.DataLoader(
            self.datasetVal,
            batch_size=self.opts.batch_size,
            shuffle=True,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        idx_prog_to_token = self.datasetVal.vocab["idx_prog_to_token"]

        total_correct = 0
        total = 0
        # Sampling needs no gradients; no_grad keeps memory flat during validation.
        with torch.no_grad():
            for step, (question, questionPrg, _, questionRounds, history, _, _) in enumerate(dataloader):
                print("\rEvaluation: [step %4d/%4d]" % (
                    step,
                    int(len(dataloader)),
                ), end=' ')
                question = question.to(self.device)
                # NOTE(review): train() and eval_with_gt() gate this on
                # `last_n_rounds < 10`; the different condition here is kept
                # as found -- confirm which one is intended.
                if self.opts.last_n_rounds is not None:
                    history = self._select_last_rounds(history, questionRounds.tolist())
                history = history.to(self.device)
                questionPrg = questionPrg.to(self.device)

                questProgsToksPred = question_net.sample(question, history)
                questProgsPred = decodeProg(questProgsToksPred, idx_prog_to_token)
                targetProgs = decodeProg(questionPrg, idx_prog_to_token, target=True)

                total_correct += sum(
                    1 for pred, gt in zip(questProgsPred, targetProgs) if pred == gt)
                total += len(targetProgs)

        self.QuestionNet.train()
        return 100.0 * (total_correct / total) if total else 0.0

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    def eval_with_gt(self):
        """Evaluate on the test split using the ground-truth dialog history.

        Prints a text report, writes it to ``opts.text_log_dir`` and pickles
        the predicted answers / programs.

        Raises:
            RuntimeError: if no caption network was loaded.
        """
        self._require_caption_net()

        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []
        all_pred_programs = []
        all_gt_programs = []
        first_failure_round = 0
        total_correct = 0
        # NOTE(review): never updated, so the "Overall VD acc." below is always 0.
        total_acc_pen = 0
        total = 0
        total_quest_prog_correct = 0

        self._wrap_data_parallel()
        self.QuestionNet = self.QuestionNet.eval()
        self.CaptionNet = self.CaptionNet.eval()
        question_net = self._question_sampler()

        # One batch is treated as one dialog.
        if self.opts.batch_size != self.opts.dialogLen:
            print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
            self.opts.batch_size = self.opts.dialogLen
        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        idx_text_to_token = self.datasetTest.vocab["idx_text_to_token"]
        idx_prog_to_token = self.datasetTest.vocab["idx_prog_to_token"]

        _iterCur = 0
        _totalCur = len(dataloader)
        with torch.no_grad():
            for (question, questionPrg, questionImgIdx, questionRounds,
                 history, historiesProg, answer) in dataloader:
                question = question.to(self.device)
                # The caption is read from the first 16 positions of the
                # (first row of the) history.
                if history.dim() == 3:
                    caption = history.detach()[:, 0, :]
                    caption = caption[:, :16].to(self.device)
                elif history.dim() == 2:
                    caption = history.detach()[:, :16].to(self.device)

                if self.opts.last_n_rounds < 10:
                    history = self._select_last_rounds(history, questionRounds.tolist())
                history = history.to(self.device)
                questionPrg = questionPrg.to(self.device)

                historiesProg = historiesProg.tolist()
                answers = [idx_text_to_token[a] for a in answer.tolist()]
                questionImgIdx = questionImgIdx.tolist()

                questProgsToksPred = question_net.sample(question, history)
                capProgsToksPred = self.CaptionNet.sample(caption)

                questProgsPred = decodeProg(questProgsToksPred, idx_prog_to_token)
                capProgsPred = decodeProg(capProgsToksPred, idx_prog_to_token)
                targetProgs = decodeProg(questionPrg, idx_prog_to_token, target=True)
                questionTypes = [targetProg[0] for targetProg in targetProgs]
                progHistories = [
                    getProgHistories(progHistToks, idx_prog_to_token)
                    for progHistToks in historiesProg
                ]

                all_pred_programs.append([capProgsPred[0]] + questProgsPred)
                all_gt_programs.append([progHistories[0]] + (targetProgs))

                pred_answers = [
                    self.getPrediction(quest_prog, cap_prog, prog_hist, img_idx)
                    for quest_prog, cap_prog, prog_hist, img_idx in zip(
                        questProgsPred, capProgsPred, progHistories, questionImgIdx)
                ]

                correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
                correct_prog = [1 if pred == ans else 0 for (pred, ans) in zip(questProgsPred, targetProgs)]

                idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
                if idx_false.shape[-1] > 0:
                    first_failure_round += idx_false[0] + 1
                else:
                    first_failure_round += self.opts.dialogLen + 1

                total_correct += sum(correct)
                total_quest_prog_correct += sum(correct_prog)
                total += len(answers)
                all_pred_answers.append(pred_answers)
                all_gt_answers.append(answers)
                all_question_types.append(questionTypes)

                # No penalty is applied: one zero per round. (The previous
                # np.zeros_like(penalty) read `penalty` before it existed.)
                penalty = np.zeros(len(answers))
                all_penalties.append(penalty)

                _iterCur += 1
                if _iterCur % self.opts.display_every == 0:
                    print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                        _iterCur, _totalCur, 100.0 * (total_correct / total)))

        ffr = 1.0 * (first_failure_round/_totalCur)/(self.opts.dialogLen + 1)

        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        vd_acc = total_acc_pen / total
        quest_prog_acc = total_quest_prog_correct / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)
        textOut += "\n --------------- Question Prog. Acc --------------- \n"
        textOut += "{}".format(100.0 * quest_prog_acc)
        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)
        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)
        textOut += "\n --------------- Done --------------- \n"
        print(textOut)

        fname = self.opts.questionNetPath.split("/")[-3] + "results_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen)
        pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
        # NOTE(review): the output locations below are hard-coded absolute paths.
        pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)
        model_name = "NSVD_stack" if "stack" in self.opts.questionNetPath else "NSVD_concat"
        experiment_name = "minecraft"
        prog_output_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/prog_output/{}_{}.pkl".format(model_name, experiment_name))

        fpath = os.path.join(self.opts.text_log_dir, fname)
        with open(fpath, "w") as f:
            f.writelines(textOut)

        with open(pred_answers_fname, "wb") as f:
            pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)

        with open(prog_output_fname, "wb") as f:
            pickle.dump((all_gt_programs, all_pred_programs, all_pred_answers), f, protocol=pickle.HIGHEST_PROTOCOL)

    def eval_with_pred(self):
        """Evaluate on the test split, feeding predicted answers back into the history.

        Every batch is treated as one dialog: the rounds are answered one after
        the other and each question, together with its predicted answer, is
        appended to the history used for the following rounds.

        Raises:
            RuntimeError: if no caption network was loaded.
        """
        self._require_caption_net()

        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []
        first_failure_round = 0
        total_correct = 0
        # NOTE(review): never updated, so the "Overall VD acc." below is always 0.
        total_acc_pen = 0
        total = 0
        # NOTE(review): filled per dialog but never written out or returned.
        samples = {}

        self._wrap_data_parallel()
        self.QuestionNet = self.QuestionNet.eval()
        self.CaptionNet = self.CaptionNet.eval()
        question_net = self._question_sampler()

        if self.opts.batch_size != self.opts.dialogLen:
            print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
            self.opts.batch_size = self.opts.dialogLen
        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        idx_text_to_token = self.datasetTest.vocab["idx_text_to_token"]
        idx_prog_to_token = self.datasetTest.vocab["idx_prog_to_token"]
        text_token_to_idx = self.datasetTest.vocab["text_token_to_idx"]

        _iterCur = 0
        _totalCur = len(dataloader)
        step = 0
        with torch.no_grad():
            for step, (question, questionPrg, questionImgIdx, questionRounds,
                       history, historiesProg, answer) in enumerate(dataloader):
                question = question.tolist()
                questions = decode(question, idx_text_to_token, target=True)
                questions = [" ".join(q) for q in questions]
                targetProgs = decode(questionPrg, idx_prog_to_token, target=True)
                questionTypes = [targetProg[0] for targetProg in targetProgs]
                targetProgs = [" ".join(p) for p in targetProgs]

                historiesProg = historiesProg.tolist()
                progHistories = [
                    getProgHistories(progHistToks, idx_prog_to_token)
                    for progHistToks in historiesProg
                ]
                answers = [idx_text_to_token[a] for a in answer.tolist()]
                questionImgIdx = questionImgIdx.tolist()

                # encoderType 2 indexes the history as 3-D, encoderType 1 as 2-D;
                # in both cases the caption is the first history entry.
                if self.opts.encoderType == 2:
                    histories_eval = [history[0, 0, :].tolist()]
                    caption = history.detach()
                    caption = caption[0, 0, :].unsqueeze(0)
                    caption = caption[:, :16].to(self.device)
                elif self.opts.encoderType == 1:
                    caption = history.detach()
                    histories_eval = [history[0, :20].tolist()]
                    caption = caption[0, :16].unsqueeze(0).to(self.device)
                cap = decode(caption, idx_text_to_token, target=False)
                capProgToksPred = self.CaptionNet.sample(caption)
                capProgPred = decode(capProgToksPred, idx_prog_to_token)[0]

                pred_answers = []
                pred_quest_prog = []
                for i, (q, prog_hist, img_idx) in enumerate(zip(question, progHistories, questionImgIdx)):
                    _round = i + 1
                    start = max(_round - self.opts.last_n_rounds, 0)
                    end = len(histories_eval)

                    quest = torch.tensor(q).unsqueeze(0).to(self.device)
                    # NOTE(review): these branches used to test encoderType 3 / 0,
                    # values constructQuestionNet cannot build, which left `hist`
                    # unbound for type 2 and never grew the history for type 1.
                    # They now follow the 2 / 1 convention used above -- confirm.
                    if self.opts.encoderType == 2:
                        hist = torch.stack(
                            [torch.tensor(h) for h in histories_eval[start:end]],
                            dim=0).unsqueeze(0).to(self.device)
                    elif self.opts.encoderType == 1:
                        histories_eval_copy = deepcopy(histories_eval)
                        histories_eval_copy[-1].append(text_token_to_idx["<END>"])
                        hist = torch.cat(
                            [torch.tensor(h) for h in histories_eval_copy[start:end]],
                            dim=-1).unsqueeze(0).to(self.device)

                    questProgsToksPred = question_net.sample(quest, hist)
                    questProgsPred = decode(questProgsToksPred, idx_prog_to_token)[0]
                    pred_quest_prog.append(" ".join(questProgsPred))
                    ans = self.getPrediction(
                        questProgsPred,
                        capProgPred,
                        prog_hist,
                        img_idx
                    )
                    ans_idx = text_token_to_idx.get(ans, text_token_to_idx["<UNK>"])

                    # Turn the question into a history entry: blank the old <END>,
                    # close the sequence with <END> and put the answer before it.
                    q[q.index(text_token_to_idx["<END>"])] = text_token_to_idx["<NULL>"]
                    q[-1] = text_token_to_idx["<END>"]
                    q.insert(-1, ans_idx)
                    if self.opts.encoderType == 2:
                        histories_eval.append(copy.deepcopy(q))
                    elif self.opts.encoderType == 1:
                        # Concatenated histories carry no per-round first / last token.
                        del q[0]
                        del q[-1]
                        histories_eval.append(copy.deepcopy(q))

                    pred_answers.append(ans)

                correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
                idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
                if idx_false.shape[-1] > 0:
                    first_failure_round += idx_false[0] + 1
                else:
                    first_failure_round += self.opts.dialogLen + 1

                correct = sum(correct)
                total_correct += correct
                total += len(answers)
                all_pred_answers.append(pred_answers)
                all_gt_answers.append(answers)
                all_question_types.append(questionTypes)
                # One zero penalty per round (same as eval_with_gt); without it
                # the per-round / per-type reports below stay empty.
                all_penalties.append(np.zeros(len(answers)))

                _iterCur += 1
                if _iterCur % self.opts.display_every == 0:
                    print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                        _iterCur, _totalCur, 100.0 * (total_correct / total)
                    ))
                samples["{}_{}".format(questionImgIdx[0], (step % 5) + 1)] = {
                    "caption": " ".join(cap[0]),
                    "cap_prog_gt": " ".join(progHistories[0][0]),
                    "cap_prog_pred": " ".join(capProgPred),
                    "questions": questions,
                    "quest_progs_gt": targetProgs,
                    "quest_progs_pred": pred_quest_prog,
                    "answers": answers,
                    "preds": pred_answers,
                    "acc": correct,
                }

        ffr = 1.0 * self.opts.dialogLen * (first_failure_round/total)

        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        vd_acc = total_acc_pen / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)
        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)
        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)
        textOut += "\n --------------- Done --------------- \n"
        print(textOut)

        # NOTE(review): after a complete pass `step` equals len(dataloader) - 1,
        # so this branch never runs and nothing is saved; it also reads
        # `self.acc_type`, which is not set anywhere in this class. Left
        # unchanged until the intended behaviour is confirmed.
        if step >= len(dataloader):
            fname = self.opts.questionNetPath.split("/")[-3] + "_results_{}_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen, self.acc_type)
            pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
            pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)
            fpath = os.path.join(self.opts.text_log_dir, fname)
            with open(fpath, "w") as f:
                f.writelines(textOut)
            with open(pred_answers_fname, "wb") as f:
                pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)

    def getPrediction(self, questProgPred, capProgPred, historyProg, imgIndex):
        """Execute the dialog programs on a scene and return the answer as a string.

        Args:
            questProgPred: predicted question program, ``[label, *args]``.
            capProgPred: predicted caption program, ``[label, *args]``.
            historyProg: list of history programs. When it has exactly one
                entry (first round) the predicted caption program is executed
                instead of it; otherwise every history program is executed.
            imgIndex: scene index handed to the executor's ``reset``.

        Returns:
            ``str`` of the executor's answer, or ``"Error"`` when any program
            is empty or fails to execute.
        """
        self.symbolicExecutor.reset(imgIndex)
        context_programs = [capProgPred] if len(historyProg) == 1 else historyProg
        try:
            for program in context_programs:
                self.symbolicExecutor.execute(program[0], program[1:])
            predAnswer = self.symbolicExecutor.execute(questProgPred[0], questProgPred[1:])
        except Exception:
            # Malformed or empty predicted programs count as a wrong answer;
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            return "Error"
        return str(predAnswer)

    def run(self, run_mode, epoch=None):
        """Seed the RNGs and dispatch to training or one of the two test modes.

        Args:
            run_mode: ``'train'``, ``'test_with_gt'`` or ``'test_with_pred'``.
                Any other value terminates the process with exit status -1.
            epoch: unused.
        """
        self.set_seed(self.opts.seed)
        if run_mode == 'train':
            self.train()
        elif run_mode == 'test_with_gt':
            print('Testing with gt answers in history')
            print('Loading ckpt {}'.format(self.opts.questionNetPath))
            state_dict = torch.load(
                self.opts.questionNetPath, map_location=self.device)['state_dict']
            self.QuestionNet.load_state_dict(state_dict)
            self.eval_with_gt()
        elif run_mode == 'test_with_pred':
            print('Testing with predicted answers in history')
            print('Loading ckpt {}'.format(self.opts.questionNetPath))
            state_dict = torch.load(
                self.opts.questionNetPath, map_location=self.device)['state_dict']
            self.QuestionNet.load_state_dict(state_dict)
            self.eval_with_pred()
        else:
            # Same exit status as before, but say why instead of exiting silently.
            print("[ERROR] Unknown run mode {!r}".format(run_mode), file=sys.stderr)
            sys.exit(-1)

    def set_seed(self, seed):
        """Sets the seed for reproducibility.

        Args:
            seed (int): The seed used
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('[INFO] Seed set to {}...'.format(seed))
def constructQuestionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    """Build the question sequence-to-sequence network.

    Args:
        opts: options object; ``opts.encoderType`` selects the encoder
            (1 -> ``QuestEncoder_1``, 2 -> ``QuestEncoder_2``).
        lenVocabText: size of the text vocabulary, passed to the encoder.
        lenVocabProg: size of the program vocabulary, passed to the decoder.
        maxLenProg: maximum program length, passed to the decoder.

    Returns:
        A ``SeqToSeqQ`` wrapping the selected encoder and a ``Decoder``.

    Raises:
        ValueError: if ``opts.encoderType`` is neither 1 nor 2 (this used to
            surface as an unbound-local error on ``encoder``).
    """
    if opts.encoderType not in (1, 2):
        raise ValueError(
            "Unsupported encoderType {!r}; expected 1 or 2".format(opts.encoderType))
    decoder = Decoder(opts, lenVocabProg, maxLenProg)
    encoder_cls = QuestEncoder_1 if opts.encoderType == 1 else QuestEncoder_2
    encoder = encoder_cls(opts, lenVocabText)
    return SeqToSeqQ(encoder, decoder)
def constructCaptionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    """Assemble the caption sequence-to-sequence network.

    The decoder is created before the encoder, then both are wrapped in a
    ``SeqToSeqC``.
    """
    caption_decoder = Decoder(opts, lenVocabProg, maxLenProg)
    caption_encoder = CaptionEncoder(opts, lenVocabText)
    return SeqToSeqC(caption_encoder, caption_decoder)
def getProgHistories(progHistToks, prgIdxToToken):
    """Split a flat sequence of program token ids into a list of programs.

    Ids 0 and 1 are skipped, id 2 closes the current program; every other id is
    mapped through ``prgIdxToToken``. Tokens after the last 2 are discarded.
    """
    programs = []
    current = []
    for token in progHistToks:
        if token == 2:
            programs.append(current)
            current = []
        elif token not in (0, 1):
            current.append(prgIdxToToken[token])
    return programs
def getHistoriesFromStack(histToks, textIdxToToken):
    """Render a stacked (one row per round) history as readable text.

    Row 0 is printed as ``"<words>.\\n"``; every later row as
    ``"<words minus last>? | <last word>.\\n"``. Ids 0 and 1 are skipped and
    id 2 ends the row. Words of a row without a 2 carry over into the next row.
    """
    chunks = ["\n"]
    words = []
    for row_idx, row in enumerate(histToks):
        for token in row:
            if token == 2:
                if row_idx == 0:
                    chunks.append(" ".join(words) + ".\n")
                else:
                    chunks.append(" ".join(words[:-1]) + "? | {}.\n".format(words[-1]))
                words = []
                break
            if token not in (0, 1):
                words.append(textIdxToToken[token])
    return "".join(chunks)
def getHistoriesFromConcat(histToks, textIdxToToken):
    """Render a concatenated history as a list of ``"<question>? | <answer>"`` strings.

    Ids 0 and 1 are skipped; each id 2 closes one entry whose last word is
    shown after the ``|`` separator.
    """
    entries = []
    words = []
    for token in histToks:
        if token == 2:
            entries.append(" ".join(words[:-1]) + "? | {}".format(words[-1]))
            words = []
        elif token not in (0, 1):
            words.append(textIdxToToken[token])
    return entries
def decodeProg(tokens, prgIdxToToken, target=False):
    """Decode a batch of program token ids into lists of program tokens.

    Args:
        tokens: batch of id sequences, either an object with ``tolist()``
            (e.g. a tensor) or an already nested list.
        prgIdxToToken: mapping from id to token; unknown ids decode to ``None``.
        target: when True the first decoded token of every program is dropped
            (presumably the start token of a ground-truth program -- TODO confirm).

    Returns:
        One list of tokens per sequence, cut at the first id 2 (``<END>``).
    """
    # Accept plain lists as well, like the sibling `decode` does.
    batch = tokens.tolist() if hasattr(tokens, "tolist") else tokens
    progsBatch = []
    for sequence in batch:
        prog = []
        for tok in sequence:
            if tok == 2:  # <END> has index 2
                break
            prog.append(prgIdxToToken.get(tok))
        progsBatch.append(prog[1:] if target else prog)
    return progsBatch
def printPred(predSoftmax, gts, prgIdxToToken):
    """Format top-1 predictions next to their targets as a printable message.

    ``predSoftmax`` and ``gts`` must agree in their first dimension; the most
    likely id at every position is taken from the last dimension of
    ``predSoftmax``.
    """
    assert predSoftmax.size(0) == gts.size(0)
    separator = "\n ------------------------ \n"
    predicted_rows = predSoftmax.topk(1)[1].squeeze(-1).tolist()
    target_rows = gts.tolist()

    parts = [separator]
    for predicted, expected in zip(predicted_rows, target_rows):
        parts.append("Prediction: ")
        parts.append("".join(prgIdxToToken.get(tok) + " " for tok in predicted))
        parts.append("\n Target : ")
        parts.append("".join(prgIdxToToken.get(tok) + " " for tok in expected))
        parts.append(separator)
    return "".join(parts)
def get_per_round_acc(preds, gts, penalties):
    """Build the per-round accuracy section of the text report.

    A correct prediction in round ``i`` contributes ``0.5 ** penalty`` to that
    round's score; the score is reported as a percentage of the number of
    samples seen for the round.
    """
    stats = {}
    for dialog_preds, dialog_gts, dialog_pens in zip(preds, gts, penalties):
        rounds = zip(list(dialog_preds), list(dialog_gts), list(dialog_pens))
        for round_no, (pred, gt, pen) in enumerate(rounds, start=1):
            entry = stats.setdefault(str(round_no), {"correct": 0, "all": 0})
            entry["all"] += 1
            if pred == gt:
                entry["correct"] += 0.5**pen

    lines = ["\n --------------- Per round Acc --------------- \n"]
    for key, entry in stats.items():
        lines.append("{}: {} %\n".format(key, 100.0 * (entry["correct"]/entry["all"])))
    return "".join(lines)
def get_per_question_type_acc(preds, gts, qtypes, penalties):
    """Build the per-question-type and per-category accuracy sections of the report.

    The category of a question type is the part before its first ``-``. A
    correct prediction contributes ``0.5 ** penalty`` to both its type and its
    category.
    """
    by_type = {}
    by_category = {}
    for dialog_preds, dialog_gts, dialog_types, dialog_pens in zip(preds, gts, qtypes, penalties):
        dialog_pens = list(dialog_pens)
        for pred, gt, qtype, pen in zip(dialog_preds, dialog_gts, dialog_types, dialog_pens):
            type_entry = by_type.setdefault(qtype, {"correct": 0, "all": 0})
            category_entry = by_category.setdefault(
                qtype.split("-")[0], {"correct": 0, "all": 0})
            type_entry["all"] += 1
            category_entry["all"] += 1
            if pred == gt:
                type_entry["correct"] += 0.5**pen
                category_entry["correct"] += 0.5**pen

    lines = ["\n --------------- Per question Type Acc --------------- \n"]
    for key, entry in by_type.items():
        lines.append("{}: {} %\n".format(key, 100.0 * (entry["correct"]/entry["all"])))
    lines.append("\n --------------- Per question Category Acc --------------- \n")
    for key, entry in by_category.items():
        lines.append("{}: {} %\n".format(key, 100.0 * (entry["correct"]/entry["all"])))
    return "".join(lines)
def decode(tokens, prgIdxToToken, target=False):
    """Decode a batch of token ids into lists of tokens.

    Args:
        tokens: batch of id sequences; objects exposing ``tolist()`` (e.g.
            tensors) are converted first, any other iterable of sequences
            (list, tuple, list subclass) is used as is.
        prgIdxToToken: mapping from id to token; unknown ids decode to ``None``.
        target: when True the first decoded token of every sequence is dropped.

    Returns:
        One list of tokens per sequence, cut at the first id 2 (``<END>``).
    """
    # Duck-typed instead of `type(tokens) != list`, which rejected list
    # subclasses and tuples by calling a non-existent `.tolist()` on them.
    if hasattr(tokens, "tolist"):
        tokens = tokens.tolist()
    progsBatch = []
    for sequence in tokens:
        prog = []
        for tok in sequence:
            if tok == 2:  # <END> has index 2
                break
            prog.append(prgIdxToToken.get(tok))
        progsBatch.append(prog[1:] if target else prog)
    return progsBatch
# Script entry point: parse the caption and question parser options, build the
# Execution harness and run it.
if __name__ == "__main__":
    optsC = OptionsC().parse()
    optsQ = OptionsQ().parse()
    exe = Execution(optsQ, optsC)
    # NOTE(review): Execution.run in this file only handles 'train',
    # 'test_with_gt' and 'test_with_pred'; "test" falls through to its else
    # branch, which terminates the process with exit status -1 before the
    # final print -- confirm the intended mode.
    exe.run("test")
    print("[INFO] Done ...")