"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1

Training / evaluation driver for the question program parser. Predicted
caption and question programs are executed by a symbolic executor (CLEVR or
Minecraft scenes) to answer the dialog questions.
"""

import copy
import json
import os
import pickle
import sys
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
from tensorboardX import SummaryWriter
from tqdm import tqdm

from clevrDialog_dataset import ClevrDialogQuestionDataset

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from executor.symbolic_executor import SymbolicExecutorClevr, SymbolicExecutorMinecraft  # noqa: E402
from models import SeqToSeqQ, QuestEncoder_1, QuestEncoder_2, Decoder, CaptionEncoder, SeqToSeqC  # noqa: E402
from optim import get_optim, adjust_lr  # noqa: E402
from options_caption_parser import Options as OptionsC  # noqa: E402
from options_question_parser import Options as OptionsQ  # noqa: E402


# Token positions taken by one dialog round in a concatenated (2-D) history
# tensor: round indices are multiplied by this value before slicing.
TOKENS_PER_ROUND = 20
# Number of leading tokens of the first history round that are fed to CaptionNet.
CAPTION_LEN = 16
# Histories are only truncated when opts.last_n_rounds is below this value;
# larger values keep the full history.
FULL_HISTORY_ROUNDS = 10
# Output locations used by the evaluation routines.
PRED_ANSWERS_DIR = "/projects/abdessaied/clevr-dialog/output/pred_answers"
PROG_OUTPUT_DIR = "/projects/abdessaied/clevr-dialog/output/prog_output"


def _unwrap(net):
    """Return the module wrapped by ``nn.DataParallel``, or ``net`` itself.

    ``nn.DataParallel`` does not proxy custom methods such as ``sample``, so
    those must be called on the wrapped module.
    """
    return net.module if isinstance(net, nn.DataParallel) else net


def _strip_module_prefix(state_dict):
    """Drop the ``module.`` key prefix added by ``nn.DataParallel`` when every key carries it."""
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _truncate_history(history, rounds, last_n_rounds):
    """Keep only a window of ``last_n_rounds`` dialog rounds of each history in the batch.

    Args:
        history: batched history tensor. A 3-D tensor is sliced along dim 1
            (one row per round); a 2-D tensor is sliced in chunks of
            ``TOKENS_PER_ROUND`` tokens and re-framed with ids 1/2 (or padded
            with ``[2, 0]``) so every entry has the same length.
        rounds: tensor with the dialog round of each batch element.
        last_n_rounds: size of the history window.

    Returns:
        The stacked, truncated history (on CPU).
    """
    truncated = []
    for i, r in enumerate(rounds.tolist()):
        start_idx = max(r - last_n_rounds, 0)
        end_idx = max(r, last_n_rounds)
        if history.dim() == 3:
            assert end_idx - start_idx == last_n_rounds
            truncated.append(history[i, :, :][start_idx:end_idx, :])
        elif history.dim() == 2:
            start_idx *= TOKENS_PER_ROUND
            end_idx *= TOKENS_PER_ROUND
            segment = history[i, :][start_idx:end_idx].cpu()
            if r > last_n_rounds:
                truncated.append(torch.cat([torch.tensor([1]), segment, torch.tensor([2])], 0))
            else:
                truncated.append(torch.cat([segment, torch.tensor([2, 0])], 0))
    return torch.stack(truncated, dim=0)


def _extract_caption(history, device):
    """Return the first ``CAPTION_LEN`` tokens of the first history round, moved to ``device``."""
    caption = history.detach()
    if history.dim() == 3:
        caption = caption[:, 0, :]
    return caption[:, :CAPTION_LEN].to(device)


def _first_failure_round(correct, dialog_len):
    """Return the 1-based round of the first wrong answer, or ``dialog_len + 1`` if all are right."""
    for round_idx, is_correct in enumerate(correct, start=1):
        if not is_correct:
            return round_idx
    return dialog_len + 1


class Execution:
    """Owns the datasets, networks, symbolic executor and logging of one run.

    Args:
        optsQ: parsed question-parser options (deep-copied into ``self.opts``).
        optsC: parsed options used to build the caption network.
    """

    def __init__(self, optsQ, optsC):
        self.opts = deepcopy(optsQ)
        if self.opts.useCuda > 0 and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            print("[INFO] Using GPU {} ...".format(torch.cuda.get_device_name(0)))
        else:
            print("[INFO] Using CPU ...")
            self.device = torch.device("cpu")

        self.loss_fn = torch.nn.NLLLoss().to(self.device)

        print("[INFO] Loading dataset ...")
        self.datasetTr = ClevrDialogQuestionDataset(
            self.opts.dataPathTr, self.opts.vocabPath, "train", "All tr data")
        self.datasetVal = ClevrDialogQuestionDataset(
            self.opts.dataPathVal, self.opts.vocabPath, "val", "All val data", train=False)
        self.datasetTest = ClevrDialogQuestionDataset(
            self.opts.dataPathTest, self.opts.vocabPath, "test", "All val data", train=False)

        self.QuestionNet = constructQuestionNet(
            self.opts,
            self.datasetTr.lenVocabText,
            self.datasetTr.lenVocabProg,
            self.datasetTr.maxLenProg,
        )

        # The caption network is only built when its checkpoint exists; the
        # evaluation routines (eval_with_gt / eval_with_pred) require it.
        if os.path.isfile(self.opts.captionNetPath):
            self.CaptionNet = constructCaptionNet(
                optsC,
                self.datasetTr.lenVocabText,
                self.datasetTr.lenVocabProg,
                self.datasetTr.maxLenProg,
            )
            print('Loading CaptionNet from {}'.format(self.opts.captionNetPath))
            # map_location lets GPU-saved checkpoints load on CPU-only machines.
            state_dict = torch.load(self.opts.captionNetPath, map_location=self.device)['state_dict']
            self.CaptionNet.load_state_dict(state_dict)
            self.CaptionNet.to(self.device)
            total_params_cap = sum(p.numel() for p in self.CaptionNet.parameters() if p.requires_grad)
            print("The caption encoder has {} trainable parameters".format(total_params_cap))

        self.QuestionNet.to(self.device)
        total_params_quest = sum(p.numel() for p in self.QuestionNet.parameters() if p.requires_grad)
        print("The question encoder has {} trainable parameters".format(total_params_quest))

        if "minecraft" in self.opts.scenesPath:
            self.symbolicExecutor = SymbolicExecutorMinecraft(self.opts.scenesPath)
        else:
            self.symbolicExecutor = SymbolicExecutorClevr(self.opts.scenesPath)

        tb_path = os.path.join(self.opts.run_dir, "tb_logdir")
        self.ckpt_path = os.path.join(self.opts.run_dir, "ckpt_dir")
        for directory in (tb_path, self.ckpt_path, self.opts.text_log_dir):
            os.makedirs(directory, exist_ok=True)

        self.writer = SummaryWriter(tb_path)
        self.iter_val = 0

        if os.path.isfile(self.opts.dependenciesPath):
            with open(self.opts.dependenciesPath, "rb") as f:
                self.dependencies = pickle.load(f)

    def _parallelize(self):
        """Wrap QuestionNet in ``nn.DataParallel`` once when several GPU ids are configured."""
        if len(self.opts.gpu_ids) > 1 and not isinstance(self.QuestionNet, nn.DataParallel):
            self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)

    def _load_question_state(self, state_dict):
        """Load QuestionNet weights regardless of DataParallel wrapping on either side."""
        _unwrap(self.QuestionNet).load_state_dict(_strip_module_prefix(state_dict))

    def _use_dialog_batches(self):
        """Force ``batch_size == dialogLen``.

        The evaluation loops treat each (unshuffled) batch as one dialog:
        batch position ``i`` is round ``i + 1``.
        """
        if self.opts.batch_size != self.opts.dialogLen:
            print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
            self.opts.batch_size = self.opts.dialogLen

    def train(self):
        """Train QuestionNet with teacher forcing for ``opts.num_iters`` iterations.

        Every ``opts.validate_every`` iterations the model is validated and a
        checkpoint is written whenever validation accuracy improves.
        """
        self.QuestionNet.train()
        self._parallelize()

        if os.path.isfile(self.opts.load_checkpoint_path):
            print("[INFO] Resume trainig from ckpt {} ...".format(self.opts.load_checkpoint_path))
            ckpt = torch.load(self.opts.load_checkpoint_path, map_location=self.device)
            print("[INFO] Checkpoint successfully loaded ...")
            self._load_question_state(ckpt['state_dict'])
            optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))
            optim.optimizer.load_state_dict(ckpt['optimizer'])
        else:
            optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))
        # Counters restart from zero even when resuming
        # (ckpt['last_iter'] / ckpt['last_epoch'] are deliberately not restored).
        _iter = 0
        epoch = 0

        trainTime = 0.0
        bestValAcc = float("-inf")
        bestCkp = 0

        while _iter < self.opts.num_iters:
            dataloader = Data.DataLoader(
                self.datasetTr,
                batch_size=self.opts.batch_size,
                shuffle=self.opts.shuffle_data,
                num_workers=self.opts.num_workers,
            )
            for batch_iter, (quest, hist, prog, questionRound, _) in enumerate(dataloader):
                if _iter >= self.opts.num_iters:
                    break
                time_start = time.time()

                # lr_decay_marks are iteration numbers, so they are checked on
                # every iteration rather than only at epoch boundaries.
                if _iter in self.opts.lr_decay_marks:
                    adjust_lr(optim, self.opts.lr_decay_factor)

                quest = quest.to(self.device)
                if self.opts.last_n_rounds < FULL_HISTORY_ROUNDS:
                    hist = _truncate_history(hist, questionRound, self.opts.last_n_rounds)
                hist = hist.to(self.device)
                prog = prog.to(self.device)

                optim.zero_grad()
                # Teacher forcing: feed prog[:, :-1], predict prog[:, 1:].
                predSoftmax, _ = self.QuestionNet(quest, hist, prog[:, :-1])
                loss = self.loss_fn(
                    predSoftmax.contiguous().view(-1, predSoftmax.size(2)),
                    prog[:, 1:].contiguous().view(-1))
                loss.backward()

                if _iter % self.opts.validate_every == 0 and _iter > 0:
                    valAcc = self.val()
                    if valAcc > bestValAcc:
                        bestValAcc = valAcc
                        bestCkp = _iter
                        print("\n[INFO] Checkpointing model @ iter {} with val accuracy {}\n".format(_iter, valAcc))
                        state = {
                            # Save the unwrapped weights so run() can load them
                            # whether or not DataParallel was used for training.
                            'state_dict': _unwrap(self.QuestionNet).state_dict(),
                            'optimizer': optim.optimizer.state_dict(),
                            'lr_base': optim.lr_base,
                            'optim': optim.lr_base,
                            'last_iter': _iter,
                            'last_epoch': epoch,
                        }
                        torch.save(
                            state,
                            os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
                        )

                loss_value = loss.item()
                self.writer.add_scalar('train/loss', loss_value, global_step=_iter)
                self.writer.add_scalar('train/lr', optim._rate, global_step=_iter)

                if _iter % self.opts.display_every == 0:
                    elapsed = trainTime + (time.time() - time_start)
                    print("\r[CLEVR-Dialog - %s (%d | %d)][epoch %2d][iter %4d/%4d][runtime %4f] loss: %.4f, lr: %.2e" % (
                        self.datasetTr.name,
                        batch_iter,
                        len(dataloader),
                        epoch,
                        _iter,
                        self.opts.num_iters,
                        elapsed,
                        loss_value,
                        optim._rate,
                    ), end=' ')

                optim.step()
                # Accumulate the time of every iteration (not just displayed ones).
                trainTime += time.time() - time_start
                _iter += 1

            epoch += 1
        print("[INFO] Avg. epoch time: {} s".format(trainTime / max(epoch, 1)))
        print("[INFO] Best model achieved val acc. {} @ iter {}".format(bestValAcc, bestCkp))

    def val(self):
        """Return the exact-match program accuracy (in %) of QuestionNet on the val split.

        History is truncated exactly as in ``train`` so validation sees the
        same inputs. The network is switched back to train mode afterwards.
        """
        self._parallelize()
        self.QuestionNet.eval()
        question_net = _unwrap(self.QuestionNet)
        prog_idx_to_token = self.datasetVal.vocab["idx_prog_to_token"]

        dataloader = Data.DataLoader(
            self.datasetVal,
            batch_size=self.opts.batch_size,
            shuffle=True,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        total_correct = 0
        total = 0
        with torch.no_grad():
            for step, (question, questionPrg, _, questionRounds, history, _, _) in enumerate(dataloader):
                print("\rEvaluation: [step %4d/%4d]" % (
                    step,
                    int(len(dataloader)),
                ), end=' ')
                question = question.to(self.device)
                if self.opts.last_n_rounds < FULL_HISTORY_ROUNDS:
                    history = _truncate_history(history, questionRounds, self.opts.last_n_rounds)
                history = history.to(self.device)

                questProgsPred = decodeProg(question_net.sample(question, history), prog_idx_to_token)
                targetProgs = decodeProg(questionPrg, prog_idx_to_token, target=True)
                total_correct += sum(pred == gt for pred, gt in zip(questProgsPred, targetProgs))
                total += len(targetProgs)
        self.QuestionNet.train()
        return 100.0 * (total_correct / total)

    def eval_with_gt(self):
        """Evaluate on the test split using the ground-truth dialog history.

        Question programs are predicted from the ground-truth history, the
        caption program from the caption, and answers come from the symbolic
        executor. Metrics are printed and written to ``opts.text_log_dir``.
        Predicted answers and programs are pickled to the output directories.
        """
        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []
        all_pred_programs = []
        all_gt_programs = []
        first_failure_round = 0
        total_correct = 0
        # NOTE(review): total_acc_pen is never incremented, so the reported
        # "VD acc." is always 0.
        total_acc_pen = 0
        total = 0
        total_quest_prog_correct = 0

        self._parallelize()
        self.QuestionNet.eval()
        self.CaptionNet.eval()
        question_net = _unwrap(self.QuestionNet)
        self._use_dialog_batches()

        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        vocab = self.datasetTest.vocab
        prog_idx_to_token = vocab["idx_prog_to_token"]
        _iterCur = 0
        _totalCur = len(dataloader)
        with torch.no_grad():
            for question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer in dataloader:
                batchSize = question.size(0)
                question = question.to(self.device)
                caption = _extract_caption(history, self.device)
                if self.opts.last_n_rounds < FULL_HISTORY_ROUNDS:
                    history = _truncate_history(history, questionRounds, self.opts.last_n_rounds)
                history = history.to(self.device)

                answers = [vocab["idx_text_to_token"][a] for a in answer.tolist()]
                questionImgIdx = questionImgIdx.tolist()

                questProgsPred = decodeProg(question_net.sample(question, history), prog_idx_to_token)
                capProgsPred = decodeProg(self.CaptionNet.sample(caption), prog_idx_to_token)
                targetProgs = decodeProg(questionPrg, prog_idx_to_token, target=True)
                questionTypes = [targetProg[0] for targetProg in targetProgs]
                progHistories = [getProgHistories(progHistToks, prog_idx_to_token)
                                 for progHistToks in historiesProg.tolist()]

                all_pred_programs.append([capProgsPred[0]] + questProgsPred)
                # NOTE(review): progHistories[0] is a *list* of programs while
                # capProgsPred[0] is a single program; progHistories[0][0] may be
                # what was intended. Kept as-is to preserve the pickle format.
                all_gt_programs.append([progHistories[0]] + targetProgs)

                pred_answers = [
                    self.getPrediction(questProgsPred[i], capProgsPred[i], progHistories[i], questionImgIdx[i])
                    for i in range(batchSize)
                ]
                correct = [1 if pred == ans else 0 for pred, ans in zip(pred_answers, answers)]
                correct_prog = [1 if pred == gt else 0 for pred, gt in zip(questProgsPred, targetProgs)]
                first_failure_round += _first_failure_round(correct, self.opts.dialogLen)

                total_correct += sum(correct)
                total_quest_prog_correct += sum(correct_prog)
                total += len(answers)
                all_pred_answers.append(pred_answers)
                all_gt_answers.append(answers)
                all_question_types.append(questionTypes)
                # No penalties: 0.5 ** 0 == 1 counts every correct answer fully.
                all_penalties.append(np.zeros(len(answers), dtype=int))

                _iterCur += 1
                if _iterCur % self.opts.display_every == 0:
                    print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                        _iterCur, _totalCur, 100.0 * (total_correct / total)))

        # Mean (1-based) round of the first wrong answer over all dialogs.
        ffr = first_failure_round / _totalCur
        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        vd_acc = total_acc_pen / total
        quest_prog_acc = total_quest_prog_correct / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)
        textOut += "\n --------------- Question Prog. Acc --------------- \n"
        textOut += "{}".format(100.0 * quest_prog_acc)
        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)
        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)
        textOut += "\n --------------- Done --------------- \n"
        print(textOut)

        run_name = self.opts.questionNetPath.split("/")[-3]
        fname = run_name + "results_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen)
        pred_answers_fname = os.path.join(
            PRED_ANSWERS_DIR,
            run_name + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen))
        model_name = "NSVD_stack" if "stack" in self.opts.questionNetPath else "NSVD_concat"
        experiment_name = "minecraft"
        prog_output_fname = os.path.join(PROG_OUTPUT_DIR, "{}_{}.pkl".format(model_name, experiment_name))
        fpath = os.path.join(self.opts.text_log_dir, fname)

        for out_dir in (PRED_ANSWERS_DIR, PROG_OUTPUT_DIR):
            os.makedirs(out_dir, exist_ok=True)
        with open(fpath, "w") as f:
            f.write(textOut)
        with open(pred_answers_fname, "wb") as f:
            pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(prog_output_fname, "wb") as f:
            pickle.dump((all_gt_programs, all_pred_programs, all_pred_answers), f, protocol=pickle.HIGHEST_PROTOCOL)

    def eval_with_pred(self):
        """Evaluate on the test split using the model's *own* answers as history.

        Each dialog is processed round by round. After answering round ``i``,
        the question is extended with the predicted answer and appended to the
        history used for later rounds. ``opts.encoderType`` 2 stacks rounds;
        1 concatenates them. Metrics are printed and written to disk.

        Raises:
            ValueError: if ``opts.encoderType`` is neither 1 nor 2.
        """
        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []
        first_failure_round = 0
        total_correct = 0
        # NOTE(review): total_acc_pen is never incremented, so the reported
        # "VD acc." is always 0.
        total_acc_pen = 0
        total = 0

        if self.opts.encoderType not in (1, 2):
            raise ValueError("Unsupported encoderType {}".format(self.opts.encoderType))

        self._parallelize()
        self.QuestionNet.eval()
        self.CaptionNet.eval()
        question_net = _unwrap(self.QuestionNet)
        self._use_dialog_batches()

        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        vocab = self.datasetTest.vocab
        text_to_idx = vocab["text_token_to_idx"]
        text_idx_to_token = vocab["idx_text_to_token"]
        prog_idx_to_token = vocab["idx_prog_to_token"]
        _iterCur = 0
        _totalCur = len(dataloader)
        with torch.no_grad():
            for question, questionPrg, questionImgIdx, _, history, historiesProg, answer in dataloader:
                question = question.tolist()
                targetProgs = decode(questionPrg, prog_idx_to_token, target=True)
                questionTypes = [targetProg[0] for targetProg in targetProgs]
                progHistories = [getProgHistories(progHistToks, prog_idx_to_token)
                                 for progHistToks in historiesProg.tolist()]
                answers = [text_idx_to_token[a] for a in answer.tolist()]
                questionImgIdx = questionImgIdx.tolist()

                # Seed the history with the caption (first round of the first element).
                if self.opts.encoderType == 2:
                    histories_eval = [history[0, 0, :].tolist()]
                    caption = history.detach()[0, 0, :].unsqueeze(0)
                    caption = caption[:, :CAPTION_LEN].to(self.device)
                else:
                    histories_eval = [history[0, :TOKENS_PER_ROUND].tolist()]
                    caption = history.detach()[0, :CAPTION_LEN].unsqueeze(0).to(self.device)

                capProgPred = decode(self.CaptionNet.sample(caption), prog_idx_to_token)[0]

                pred_answers = []
                for i, (q, prog_hist, img_idx) in enumerate(zip(question, progHistories, questionImgIdx)):
                    _round = i + 1
                    start = 0 if _round <= self.opts.last_n_rounds else _round - self.opts.last_n_rounds
                    end = len(histories_eval)
                    quest = torch.tensor(q).unsqueeze(0).to(self.device)
                    # NOTE(review): the empty-string vocab keys used below look like
                    # special-token names (end/null/unknown markers) that may have
                    # been lost; confirm against the vocabulary file.
                    if self.opts.encoderType == 2:
                        hist = torch.stack(
                            [torch.tensor(h) for h in histories_eval[start:end]], dim=0
                        ).unsqueeze(0).to(self.device)
                    else:
                        histories_eval_copy = deepcopy(histories_eval)
                        histories_eval_copy[-1].append(text_to_idx[""])
                        hist = torch.cat(
                            [torch.tensor(h) for h in histories_eval_copy[start:end]], dim=-1
                        ).unsqueeze(0).to(self.device)

                    questProgsPred = decode(question_net.sample(quest, hist), prog_idx_to_token)[0]
                    ans = self.getPrediction(questProgsPred, capProgPred, prog_hist, img_idx)

                    # Turn the question into a history round carrying the predicted answer.
                    ans_idx = text_to_idx.get(ans, text_to_idx[""])
                    q[q.index(text_to_idx[""])] = text_to_idx[""]
                    q[-1] = text_to_idx[""]
                    q.insert(-1, ans_idx)
                    if self.opts.encoderType == 2:
                        histories_eval.append(copy.deepcopy(q))
                    else:
                        # Concatenated histories drop the per-round framing tokens.
                        del q[0]
                        del q[-1]
                        histories_eval.append(copy.deepcopy(q))
                    pred_answers.append(ans)

                correct = [1 if pred == ans else 0 for pred, ans in zip(pred_answers, answers)]
                first_failure_round += _first_failure_round(correct, self.opts.dialogLen)
                total_correct += sum(correct)
                total += len(answers)
                all_pred_answers.append(pred_answers)
                all_gt_answers.append(answers)
                all_question_types.append(questionTypes)
                # No penalties: 0.5 ** 0 == 1 counts every correct answer fully.
                all_penalties.append(np.zeros(len(answers), dtype=int))

                _iterCur += 1
                if _iterCur % self.opts.display_every == 0:
                    print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                        _iterCur, _totalCur, 100.0 * (total_correct / total)
                    ))

        # Mean (1-based) round of the first wrong answer over all dialogs.
        ffr = first_failure_round / _totalCur
        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        vd_acc = total_acc_pen / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)
        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)
        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)
        textOut += "\n --------------- Done --------------- \n"
        print(textOut)

        run_name = self.opts.questionNetPath.split("/")[-3]
        # NOTE(review): acc_type is not set anywhere in this file; "pred" is a fallback.
        acc_type = getattr(self, "acc_type", "pred")
        fname = run_name + "_results_{}_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen, acc_type)
        pred_answers_fname = os.path.join(
            PRED_ANSWERS_DIR,
            run_name + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen))
        fpath = os.path.join(self.opts.text_log_dir, fname)
        os.makedirs(PRED_ANSWERS_DIR, exist_ok=True)
        with open(fpath, "w") as f:
            f.write(textOut)
        with open(pred_answers_fname, "wb") as f:
            pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)

    def getPrediction(self, questProgPred, capProgPred, historyProg, imgIndex):
        """Answer one question by executing programs on the scene ``imgIndex``.

        In round one (``historyProg`` holds a single program) the *predicted*
        caption program is executed first. Otherwise every history program is
        executed in order. The question program is executed last.

        Returns:
            The executor's answer as a string, or ``"Error"`` if any program
            fails to execute (including empty or malformed programs).
        """
        self.symbolicExecutor.reset(imgIndex)
        programs = [capProgPred] if len(historyProg) == 1 else list(historyProg)
        try:
            for prog in programs:
                self.symbolicExecutor.execute(prog[0], prog[1:])
            predAnswer = self.symbolicExecutor.execute(questProgPred[0], questProgPred[1:])
        except Exception:
            return "Error"
        return str(predAnswer)

    def run(self, run_mode, epoch=None):
        """Seed all RNGs, then dispatch to training or one of the test modes.

        Args:
            run_mode: ``'train'``, ``'test_with_gt'`` or ``'test_with_pred'``.
                Any other value prints an error and exits with status -1.
            epoch: unused; kept for interface compatibility.
        """
        self.set_seed(self.opts.seed)
        if run_mode == 'train':
            self.train()
        elif run_mode in ('test_with_gt', 'test_with_pred'):
            if run_mode == 'test_with_gt':
                print('Testing with gt answers in history')
            else:
                print('Testing with predicted answers in history')
            print('Loading ckpt {}'.format(self.opts.questionNetPath))
            state_dict = torch.load(self.opts.questionNetPath, map_location=self.device)['state_dict']
            self._load_question_state(state_dict)
            if run_mode == 'test_with_gt':
                self.eval_with_gt()
            else:
                self.eval_with_pred()
        else:
            print("[ERROR] Unknown run mode '{}' (expected 'train', 'test_with_gt' or 'test_with_pred')".format(
                run_mode))
            sys.exit(-1)

    def set_seed(self, seed):
        """Sets the seed for reproducibility.

        Args:
            seed (int): The seed used
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('[INFO] Seed set to {}...'.format(seed))


def constructQuestionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    """Build the question seq2seq network for ``opts.encoderType`` (1 or 2).

    Raises:
        ValueError: for any other encoder type.
    """
    decoder = Decoder(opts, lenVocabProg, maxLenProg)
    if opts.encoderType == 1:
        encoder = QuestEncoder_1(opts, lenVocabText)
    elif opts.encoderType == 2:
        encoder = QuestEncoder_2(opts, lenVocabText)
    else:
        raise ValueError("Unsupported encoderType {}".format(opts.encoderType))
    return SeqToSeqQ(encoder, decoder)


def constructCaptionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    """Build the caption seq2seq network."""
    decoder = Decoder(opts, lenVocabProg, maxLenProg)
    encoder = CaptionEncoder(opts, lenVocabText)
    return SeqToSeqC(encoder, decoder)


def getProgHistories(progHistToks, prgIdxToToken):
    """Split a flat sequence of program-token ids into one token list per program.

    Ids 0, 1 and 2 are skipped; id 2 closes the current program. Tokens after
    the last 2 are dropped.
    """
    progHist = []
    temp = []
    for tok in progHistToks:
        if tok not in (0, 1, 2):
            temp.append(prgIdxToToken[tok])
        if tok == 2:
            progHist.append(temp)
            temp = []
    return progHist


def getHistoriesFromStack(histToks, textIdxToToken):
    """Render a stacked (one row per round) history as readable text.

    Round 0 is printed as a sentence. Later rounds are printed as
    ``"<question>? | <answer>."``, where the answer is taken to be the last
    token before id 2.
    """
    histories = "\n"
    temp = []
    for i, roundToks in enumerate(histToks):
        for tok in roundToks:
            if tok not in (0, 1, 2):
                temp.append(textIdxToToken[tok])
            if tok == 2:
                if i == 0:
                    histories += " ".join(temp) + ".\n"
                else:
                    histories += " ".join(temp[:-1]) + "? | {}.\n".format(temp[-1])
                temp = []
                break
    return histories


def getHistoriesFromConcat(histToks, textIdxToToken):
    """Render a concatenated history as a list of ``"<question>? | <answer>"`` strings."""
    histories = []
    temp = []
    for tok in histToks:
        if tok not in (0, 1, 2):
            temp.append(textIdxToToken[tok])
        if tok == 2:
            histories.append(" ".join(temp[:-1]) + "? | {}".format(temp[-1]))
            temp = []
    return histories


def decodeProg(tokens, prgIdxToToken, target=False):
    """Decode a batch of program-token ids into token lists (see ``decode``).

    Accepts a tensor or a nested list.
    """
    return decode(tokens, prgIdxToToken, target=target)


def printPred(predSoftmax, gts, prgIdxToToken):
    """Format top-1 predictions from ``predSoftmax`` next to the targets ``gts``."""
    assert predSoftmax.size(0) == gts.size(0)
    tokens = predSoftmax.topk(1)[1].squeeze(-1).tolist()
    gts = gts.tolist()
    message = "\n ------------------------ \n"
    for token, gt in zip(tokens, gts):
        message += "Prediction: "
        for tok in token:
            message += prgIdxToToken.get(tok) + " "
        message += "\n Target : "
        for tok in gt:
            message += prgIdxToToken.get(tok) + " "
        message += "\n ------------------------ \n"
    return message


def get_per_round_acc(preds, gts, penalties):
    """Return a text report of answer accuracy per dialog round.

    A correct answer contributes ``0.5 ** penalty`` to its round's score.
    """
    res = {}
    for img_preds, img_gt, img_pen in zip(preds, gts, penalties):
        for i, (pred, gt, pen) in enumerate(zip(img_preds, img_gt, img_pen)):
            stats = res.setdefault(str(i + 1), {"correct": 0, "all": 0})
            stats["all"] += 1
            if pred == gt:
                stats["correct"] += 0.5 ** pen

    textOut = "\n --------------- Per round Acc --------------- \n"
    for k in res:
        textOut += "{}: {} %\n".format(k, 100.0 * (res[k]["correct"] / res[k]["all"]))
    return textOut


def get_per_question_type_acc(preds, gts, qtypes, penalties):
    """Return a text report of answer accuracy per question type and category.

    The category is the part of the type before the first ``-``. A correct
    answer contributes ``0.5 ** penalty``.
    """
    res1 = {}
    res2 = {}
    for img_preds, img_gt, img_qtypes, img_pen in zip(preds, gts, qtypes, penalties):
        for pred, gt, temp, pen in zip(img_preds, img_gt, img_qtypes, img_pen):
            type_stats = res1.setdefault(temp, {"correct": 0, "all": 0})
            cat_stats = res2.setdefault(temp.split("-")[0], {"correct": 0, "all": 0})
            type_stats["all"] += 1
            cat_stats["all"] += 1
            if pred == gt:
                type_stats["correct"] += 0.5 ** pen
                cat_stats["correct"] += 0.5 ** pen

    textOut = "\n --------------- Per question Type Acc --------------- \n"
    for k in res1:
        textOut += "{}: {} %\n".format(k, 100.0 * (res1[k]["correct"] / res1[k]["all"]))
    textOut += "\n --------------- Per question Category Acc --------------- \n"
    for k in res2:
        textOut += "{}: {} %\n".format(k, 100.0 * (res2[k]["correct"] / res2[k]["all"]))
    return textOut


def decode(tokens, prgIdxToToken, target=False):
    """Decode a batch of token-id sequences into token lists.

    Each sequence is cut at the first id 2. Unknown ids map to ``None`` via
    ``dict.get``. With ``target=True`` the first decoded token is dropped.

    Args:
        tokens: tensor (anything with ``tolist``) or nested list of ids.
        prgIdxToToken: mapping from id to token.
        target: drop the leading token of each sequence.
    """
    if not isinstance(tokens, list):
        tokens = tokens.tolist()
    progsBatch = []
    for token in tokens:
        prog = []
        for tok in token:
            if tok == 2:
                break
            prog.append(prgIdxToToken.get(tok))
        if target:
            prog = prog[1:]
        progsBatch.append(prog)
    return progsBatch


if __name__ == "__main__":
    optsC = OptionsC().parse()
    optsQ = OptionsQ().parse()
    exe = Execution(optsQ, optsC)
    # NOTE(review): "test" is not a mode accepted by Execution.run
    # ('train', 'test_with_gt', 'test_with_pred'), so this exits with -1.
    exe.run("test")
    print("[INFO] Done ...")