Commit 9d8b93db26: "make code public"
26 changed files with 11937 additions and 0 deletions
prog_generator/clevrDialog_dataset.py (new file, 94 lines)
@@ -0,0 +1,94 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
import h5py
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def invertDict(_dict):
|
||||
return {v: k for k, v in _dict.items()}
|
||||
|
||||
|
||||
class ClevrDialogDataset(Dataset):
|
||||
def __init__(self, dataPath, vocabPath, split, indStart=0, indEnd=-1):
|
||||
super(ClevrDialogDataset, self).__init__()
|
||||
self.data = h5py.File(dataPath, "r")
|
||||
with open(vocabPath, "r") as f:
|
||||
self.vocab = json.load(f)
|
||||
self.vocab["idx_text_to_token"] = invertDict(self.vocab["text_token_to_idx"])
|
||||
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
||||
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
||||
self.lenVocabText = len(self.vocab["text_token_to_idx"])
|
||||
self.lenVocabProg = len(self.vocab["prog_token_to_idx"])
|
||||
|
||||
self.split = split
|
||||
self.indStart = indStart
|
||||
self.indEnd = indEnd
|
||||
self.maxSamples = indEnd - indStart
|
||||
self.maxLenProg = 6
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ClevrDialogCaptionDataset(ClevrDialogDataset):
|
||||
def __init__(self, dataPath, vocabPath, split, name, indStart=0, indEnd=-1):
|
||||
super(ClevrDialogCaptionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
||||
self.captions = torch.LongTensor(np.asarray(self.data["captions"], dtype=np.int64)[indStart: indEnd])
|
||||
self.captionsPrgs = torch.LongTensor(np.asarray(self.data["captionProgs"], dtype=np.int64)[indStart: indEnd])
|
||||
self.name = name
|
||||
|
||||
def __len__(self):
|
||||
return len(self.captions)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert idx < len(self)
|
||||
caption = self.captions[idx][:16]
|
||||
captionPrg = self.captionsPrgs[idx]
|
||||
return caption, captionPrg
|
||||
|
||||
|
||||
class ClevrDialogQuestionDataset(ClevrDialogDataset):
|
||||
def __init__(self, dataPath, vocabPath, split, name, train=True, indStart=0, indEnd=-1):
|
||||
super(ClevrDialogQuestionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
||||
self.questions = torch.LongTensor(np.asarray(self.data["questions"], dtype=np.int64)[indStart: indEnd])
|
||||
self.quesProgs = torch.LongTensor(np.asarray(self.data["questionProgs"], dtype=np.int64)[indStart: indEnd])
|
||||
self.questionRounds = torch.LongTensor(np.asarray(self.data["questionRounds"], dtype=np.int64)[indStart: indEnd])
|
||||
self.questionImgIdx = torch.LongTensor(np.asarray(self.data["questionImgIdx"], dtype=np.int64)[indStart: indEnd])
|
||||
self.histories = torch.LongTensor(np.asarray(self.data["histories"], dtype=np.int64)[indStart: indEnd])
|
||||
self.historiesProgs = torch.LongTensor(np.asarray(self.data["historiesProg"], dtype=np.int64)[indStart: indEnd])
|
||||
|
||||
self.answers = torch.LongTensor(np.asarray(self.data["answers"], dtype=np.int64)[indStart: indEnd])
|
||||
self.name = name
|
||||
self.train = train
|
||||
|
||||
def __len__(self):
|
||||
return len(self.questions)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert idx < len(self)
|
||||
question = self.questions[idx]
|
||||
questionPrg = self.quesProgs[idx]
|
||||
questionImgIdx = self.questionImgIdx[idx]
|
||||
questionRound = self.questionRounds[idx]
|
||||
|
||||
history = self.histories[idx]
|
||||
historiesProg = self.historiesProgs[idx]
|
||||
|
||||
answer = self.answers[idx]
|
||||
if self.train:
|
||||
return question, history, questionPrg, questionRound, answer
|
||||
else:
|
||||
return question, questionPrg, questionImgIdx, questionRound, history, historiesProg, answer
|
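Usage sketch for these dataset classes (illustration only, not part of the committed file): the h5 and vocab paths below are hypothetical and must match the keys read above.

    import torch.utils.data as Data

    # hypothetical file names; the real ones come from the preprocessing step
    dataset = ClevrDialogCaptionDataset(
        "data/clevr_dialog_train.h5", "data/vocab.json",
        split="train", name="Captions Tr")

    loader = Data.DataLoader(dataset, batch_size=64, shuffle=True)
    caption, captionPrg = next(iter(loader))
    # caption:    (64, 16) LongTensor of text-token ids (captions are truncated to 16 tokens)
    # captionPrg: (64, L) LongTensor of program-token ids, L fixed by preprocessing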
prog_generator/models.py (new file, 476 lines)
@@ -0,0 +1,476 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FC(nn.Module):
|
||||
def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
|
||||
super(FC, self).__init__()
|
||||
self.dropout_r = dropout_r
|
||||
self.use_relu = use_relu
|
||||
|
||||
self.linear = nn.Linear(in_size, out_size)
|
||||
|
||||
if use_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
if dropout_r > 0:
|
||||
self.dropout = nn.Dropout(dropout_r)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
|
||||
if self.use_relu:
|
||||
x = self.relu(x)
|
||||
|
||||
if self.dropout_r > 0:
|
||||
x = self.dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
|
||||
super(MLP, self).__init__()
|
||||
|
||||
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
||||
self.linear = nn.Linear(mid_size, out_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.fc(x))
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, size, eps=1e-6):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
self.a_2 = nn.Parameter(torch.ones(size))
|
||||
self.b_2 = nn.Parameter(torch.zeros(size))
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(-1, keepdim=True)
|
||||
std = x.std(-1, keepdim=True)
|
||||
|
||||
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
||||
|
||||
|
||||
class MHAtt(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(MHAtt, self).__init__()
|
||||
self.opts = opts
|
||||
|
||||
self.linear_v = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_k = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_q = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_merge = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
self.dropout = nn.Dropout(opts.dropout)
|
||||
|
||||
def forward(self, v, k, q, mask):
|
||||
n_batches = q.size(0)
|
||||
|
||||
v = self.linear_v(v).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
k = self.linear_k(k).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
q = self.linear_q(q).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
atted = self.att(v, k, q, mask)
|
||||
atted = atted.transpose(1, 2).contiguous().view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.hiddenDim
|
||||
)
|
||||
|
||||
atted = self.linear_merge(atted)
|
||||
|
||||
return atted
|
||||
|
||||
def att(self, value, key, query, mask):
|
||||
d_k = query.size(-1)
|
||||
|
||||
scores = torch.matmul(
|
||||
query, key.transpose(-2, -1)
|
||||
) / math.sqrt(d_k)
|
||||
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask, -1e9)
|
||||
|
||||
att_map = F.softmax(scores, dim=-1)
|
||||
att_map = self.dropout(att_map)
|
||||
|
||||
return torch.matmul(att_map, value)
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(FFN, self).__init__()
|
||||
|
||||
self.mlp = MLP(
|
||||
in_size=opts.hiddenDim,
|
||||
mid_size=opts.FeedForwardSize,
|
||||
out_size=opts.hiddenDim,
|
||||
dropout_r=opts.dropout,
|
||||
use_relu=True
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class SA(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(SA, self).__init__()
|
||||
self.mhatt = MHAtt(opts)
|
||||
self.ffn = FFN(opts)
|
||||
|
||||
self.dropout1 = nn.Dropout(opts.dropout)
|
||||
self.norm1 = LayerNorm(opts.hiddenDim)
|
||||
|
||||
self.dropout2 = nn.Dropout(opts.dropout)
|
||||
self.norm2 = LayerNorm(opts.hiddenDim)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.norm1(x + self.dropout1(
|
||||
self.mhatt(x, x, x, x_mask)
|
||||
))
|
||||
|
||||
x = self.norm2(x + self.dropout2(
|
||||
self.ffn(x)
|
||||
))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AttFlat(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(AttFlat, self).__init__()
|
||||
self.opts = opts
|
||||
|
||||
self.mlp = MLP(
|
||||
in_size=opts.hiddenDim,
|
||||
mid_size=opts.FlatMLPSize,
|
||||
out_size=opts.FlatGlimpses,
|
||||
dropout_r=opts.dropout,
|
||||
use_relu=True
|
||||
)
|
||||
# FLAT_GLIMPSES = 1
|
||||
self.linear_merge = nn.Linear(
|
||||
opts.hiddenDim * opts.FlatGlimpses,
|
||||
opts.FlatOutSize
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
att = self.mlp(x)
|
||||
att = att.masked_fill(
|
||||
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
||||
-1e9
|
||||
)
|
||||
att = F.softmax(att, dim=1)
|
||||
|
||||
att_list = []
|
||||
for i in range(self.opts.FlatGlimpses):
|
||||
att_list.append(
|
||||
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
||||
)
|
||||
|
||||
x_atted = torch.cat(att_list, dim=1)
|
||||
x_atted = self.linear_merge(x_atted)
|
||||
|
||||
return x_atted
|
||||
|
||||
class CaptionEncoder(nn.Module):
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(CaptionEncoder, self).__init__()
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
bidirectional = opts.bidirectional > 0
|
||||
self.lstmC = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional
|
||||
)
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
opts.hiddenSizeHead *= 2
|
||||
opts.FlatOutSize *= 2
|
||||
|
||||
self.attCap = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
self.attFlatCap = AttFlat(opts)
|
||||
self.fc = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, cap, hist=None):
|
||||
capMask = self.make_mask(cap.unsqueeze(2))
|
||||
cap = self.embedding(cap)
|
||||
cap, (_, _) = self.lstmC(cap)
|
||||
capO = cap.detach().clone()
|
||||
|
||||
for attC in self.attCap:
|
||||
cap = attC(cap, capMask)
|
||||
# (batchSize, 512)
|
||||
cap = self.attFlatCap(cap, capMask)
|
||||
encOut = self.fc(cap)
|
||||
return encOut, capO
|
||||
|
||||
class QuestEncoder_1(nn.Module):
|
||||
"""
|
||||
Concat encoder
|
||||
"""
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(QuestEncoder_1, self).__init__()
|
||||
bidirectional = opts.bidirectional > 0
|
||||
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
self.lstmQ = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
bidirectional=bidirectional,
|
||||
batch_first=True
|
||||
)
|
||||
|
||||
self.lstmH = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
bidirectional=bidirectional,
|
||||
batch_first=True)
|
||||
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
opts.hiddenSizeHead *= 2
|
||||
opts.FlatOutSize *= 2
|
||||
self.attQues = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
self.attHist = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
|
||||
self.attFlatQuest = AttFlat(opts)
|
||||
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, quest, hist):
|
||||
questMask = self.make_mask(quest.unsqueeze(2))
|
||||
histMask = self.make_mask(hist.unsqueeze(2))
|
||||
|
||||
# quest = F.tanh(self.embedding(quest))
|
||||
quest = self.embedding(quest)
|
||||
|
||||
quest, (_, _) = self.lstmQ(quest)
|
||||
questO = quest.detach().clone()
|
||||
|
||||
hist = self.embedding(hist)
|
||||
hist, (_, _) = self.lstmH(hist)
|
||||
|
||||
for attQ, attH in zip(self.attQues, self.attHist):
|
||||
quest = attQ(quest, questMask)
|
||||
hist = attH(hist, histMask)
|
||||
# (batchSize, 512)
|
||||
quest = self.attFlatQuest(quest, questMask)
|
||||
|
||||
# hist: (batchSize, length, 512)
|
||||
attWeights = torch.sum(torch.mul(hist, quest.unsqueeze(1)), -1)
|
||||
attWeights = torch.softmax(attWeights, -1)
|
||||
hist = torch.sum(torch.mul(hist, attWeights.unsqueeze(2)), 1)
|
||||
encOut = self.fc(torch.cat([quest, hist], -1))
|
||||
|
||||
return encOut, questO
|
||||
|
||||
# Masking
|
||||
def make_mask(self, feature):
|
||||
return (torch.sum(
|
||||
torch.abs(feature),
|
||||
dim=-1
|
||||
) == 0).unsqueeze(1).unsqueeze(2)
|
||||
|
||||
|
||||
class QuestEncoder_2(nn.Module):
|
||||
"""
|
||||
Stack encoder
|
||||
"""
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(QuestEncoder_2, self).__init__()
|
||||
bidirectional = opts.bidirectional > 0
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
self.lstmQ = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
|
||||
self.lstmH = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
|
||||
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, quest, hist):
|
||||
|
||||
quest = F.tanh(self.embedding(quest))
|
||||
quest, (questH, _) = self.lstmQ(quest)
|
||||
|
||||
# concatenate the last hidden states from the forward and backward pass
|
||||
# of the bidirectional lstm
|
||||
lastHiddenForward = questH[1:2, :, :].squeeze(0)
|
||||
lastHiddenBackward = questH[3:4, :, :].squeeze(0)
|
||||
|
||||
# questH: (batchSize, 512)
|
||||
questH = torch.cat([lastHiddenForward, lastHiddenBackward], -1)
|
||||
|
||||
questO = quest.detach().clone()
|
||||
|
||||
hist = F.tanh(self.embedding(hist))
|
||||
numRounds = hist.size(1)
|
||||
histFeat = []
|
||||
for i in range(numRounds):
|
||||
round_i = hist[:, i, :, :]
|
||||
_, (round_i_h, _) = self.lstmH(round_i)
|
||||
|
||||
#Same as before
|
||||
lastHiddenForward = round_i_h[1:2, :, :].squeeze(0)
|
||||
lastHiddenBackward = round_i_h[3:4, :, :].squeeze(0)
|
||||
histFeat.append(torch.cat([lastHiddenForward, lastHiddenBackward], -1))
|
||||
|
||||
# hist: (batchSize, rounds, 512)
|
||||
histFeat = torch.stack(histFeat, 1)
|
||||
attWeights = torch.sum(torch.mul(histFeat, questH.unsqueeze(1)), -1)
|
||||
attWeights = torch.softmax(attWeights, -1)
|
||||
histFeat = torch.sum(torch.mul(histFeat, attWeights.unsqueeze(2)), 1)
|
||||
encOut = self.fc(torch.cat([questH, histFeat], -1))
|
||||
return encOut, questO
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, opts, progVocabSize, maxLen, startID=1, endID=2):
|
||||
super(Decoder, self).__init__()
|
||||
self.numLayers = opts.numLayers
|
||||
self.bidirectional = opts.bidirectional > 0
|
||||
self.maxLen = maxLen
|
||||
self.startID = startID
|
||||
self.endID = endID
|
||||
|
||||
self.embedding = nn.Embedding(progVocabSize, opts.embedDim)
|
||||
self.lstmProg = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=2*opts.hiddenDim if self.bidirectional else opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
# bidirectional=self.bidirectional,
|
||||
)
|
||||
hiddenDim = opts.hiddenDim
|
||||
if self.bidirectional:
|
||||
hiddenDim *= 2
|
||||
|
||||
self.fcAtt = nn.Linear(2*hiddenDim, hiddenDim)
|
||||
self.fcOut = nn.Linear(hiddenDim, progVocabSize)
|
||||
|
||||
def initPrgHidden(self, encOut):
|
||||
hidden = [encOut for _ in range(self.numLayers)]
|
||||
hidden = torch.stack(hidden, 0).contiguous()
|
||||
return hidden, hidden
|
||||
|
||||
def forwardStep(self, prog, progH, questO):
|
||||
batchSize = prog.size(0)
|
||||
inputDim = questO.size(1)
|
||||
prog = self.embedding(prog)
|
||||
outProg, progH = self.lstmProg(prog, progH)
|
||||
|
||||
att = torch.bmm(outProg, questO.transpose(1, 2))
|
||||
att = F.softmax(att.view(-1, inputDim), 1).view(batchSize, -1, inputDim)
|
||||
context = torch.bmm(att, questO)
|
||||
# (batchSize, progLength, hiddenDim)
|
||||
out = F.tanh(self.fcAtt(torch.cat([outProg, context], dim=-1)))
|
||||
|
||||
# (batchSize, progLength, progVocabSize)
|
||||
out = self.fcOut(out)
|
||||
predSoftmax = F.log_softmax(out, 2)
|
||||
return predSoftmax, progH
|
||||
|
||||
def forward(self, prog, encOut, questO):
|
||||
progH = self.initPrgHidden(encOut)
|
||||
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||
|
||||
return predSoftmax, progH
|
||||
|
||||
def sample(self, encOut, questO):
|
||||
batchSize = encOut.size(0)
|
||||
cudaFlag = encOut.is_cuda
|
||||
progH = self.initPrgHidden(encOut)
|
||||
# prog = progCopy[:, 0:3]
|
||||
prog = torch.LongTensor(batchSize, 1).fill_(self.startID)
|
||||
# prog = torch.cat((progStart, progEnd), -1)
|
||||
if cudaFlag:
|
||||
prog = prog.cuda()
|
||||
outputLogProbs = []
|
||||
outputTokens = []
|
||||
|
||||
def decode(i, output):
|
||||
tokens = output.topk(1, dim=-1)[1].view(batchSize, -1)
|
||||
return tokens
|
||||
|
||||
for i in range(self.maxLen):
|
||||
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||
prog = decode(i, predSoftmax)
|
||||
|
||||
return outputTokens, outputLogProbs
|
||||
|
||||
|
||||
class SeqToSeqC(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
super(SeqToSeqC, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, cap, imgFeat, prog):
|
||||
encOut, capO = self.encoder(cap, imgFeat)
|
||||
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
||||
return predSoftmax, progHC
|
||||
|
||||
|
||||
class SeqToSeqQ(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
super(SeqToSeqQ, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, quest, hist, prog):
|
||||
encOut, questO = self.encoder(quest, hist)
|
||||
predSoftmax, progHC = self.decoder(prog, encOut, questO)
|
||||
return predSoftmax, progHC
|
||||
|
||||
def sample(self, quest, hist):
|
||||
with torch.no_grad():
|
||||
encOut, questO = self.encoder(quest, hist)
|
||||
outputTokens, outputLogProbs = self.decoder.sample(encOut, questO)
|
||||
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
|
||||
return outputTokens
|
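Composition sketch (illustration only, not part of the committed file): how the modules above combine into the caption-program parser, mirroring constructNet in train_caption_parser.py below. The option values are the defaults from options_caption_parser.py; the vocab sizes are made up. Note that CaptionEncoder doubles opts.hiddenDim, opts.hiddenSizeHead, and opts.FlatOutSize in place when bidirectional is set, so the decoder must be built first, exactly as constructNet does.

    import torch
    from types import SimpleNamespace

    # defaults from options_caption_parser.py; vocab sizes are made up
    opts = SimpleNamespace(
        embedDim=300, hiddenDim=512, numLayers=2, dropout=0.1,
        multiHead=8, hiddenSizeHead=64, FeedForwardSize=2048,
        FlatMLPSize=512, FlatGlimpses=1, FlatOutSize=512,
        layers=6, bidirectional=1)

    decoder = Decoder(opts, progVocabSize=50, maxLen=6)   # build decoder first
    encoder = CaptionEncoder(opts, textVocabSize=100)     # doubles opts.hiddenDim et al.
    net = SeqToSeqC(encoder, decoder)

    cap = torch.randint(1, 100, (8, 16))   # batch of 8 dummy captions
    prog = torch.randint(1, 50, (8, 6))    # teacher-forcing program tokens
    predSoftmax, _ = net(cap, prog)        # (8, 6, 50) log-probabilities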
prog_generator/optim.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# adapted from https://github.com/MILVLG/mcan-vqa/blob/master/core/model/optim.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.optim as Optim
|
||||
|
||||
|
||||
class WarmupOptimizer(object):
|
||||
def __init__(self, lr_base, optimizer, data_size, batch_size):
|
||||
self.optimizer = optimizer
|
||||
self._step = 0
|
||||
self.lr_base = lr_base
|
||||
self._rate = 0
|
||||
self.data_size = data_size
|
||||
self.batch_size = batch_size
|
||||
|
||||
def step(self):
|
||||
self._step += 1
|
||||
|
||||
rate = self.rate()
|
||||
for p in self.optimizer.param_groups:
|
||||
p['lr'] = rate
|
||||
self._rate = rate
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def rate(self, step=None):
|
||||
if step is None:
|
||||
step = self._step
|
||||
|
||||
if step <= int(self.data_size / self.batch_size * 1):
|
||||
r = self.lr_base * 1/2.
|
||||
else:
|
||||
r = self.lr_base
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def get_optim(opts, model, data_size, lr_base=None):
|
||||
if lr_base is None:
|
||||
lr_base = opts.lr
|
||||
|
||||
if opts.optim == 'adam':
|
||||
optim = Optim.Adam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=0,
|
||||
betas=opts.betas,
|
||||
eps=opts.eps,
|
||||
|
||||
)
|
||||
elif opts.optim == 'rmsprop':
|
||||
optim = Optim.RMSprop(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=0,
|
||||
eps=opts.eps,
|
||||
weight_decay=opts.weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError('{} optimizer is not supported'.fromat(opts.optim))
|
||||
return WarmupOptimizer(
|
||||
lr_base,
|
||||
optim,
|
||||
data_size,
|
||||
opts.batch_size
|
||||
)
|
||||
|
||||
def adjust_lr(optim, decay_r):
|
||||
optim.lr_base *= decay_r
|
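Training-loop usage sketch for the warmup optimizer (illustration only, not part of the committed file), with a stand-in model and option namespace:

    import torch
    import torch.nn as nn
    from types import SimpleNamespace

    opts = SimpleNamespace(optim='adam', lr=1e-3, betas=[0.9, 0.98],
                           eps=1e-9, batch_size=64, weight_decay=1e-6)
    model = nn.Linear(8, 2)  # stand-in model

    optim = get_optim(opts, model, data_size=6400)
    for _ in range(3):
        optim.zero_grad()
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        optim.step()          # sets the warmup-scaled lr, then steps

    adjust_lr(optim, 0.5)     # halve the base lr at a decay mark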
prog_generator/options_caption_parser.py (new file, 283 lines)
@@ -0,0 +1,283 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
# --------------------------------------------------------
|
||||
# adapted from https://github.com/kexinyi/ns-vqa/blob/master/scene_parse/attr_net/options.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import utils
|
||||
import torch
|
||||
|
||||
|
||||
class Options():
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self):
|
||||
self.parser.add_argument(
|
||||
'--mode',
|
||||
required=True,
|
||||
type=str,
|
||||
choices=['train', 'test'],
|
||||
help='The mode of the experiment')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--run_dir',
|
||||
required=True,
|
||||
type=str,
|
||||
help='The experiment directory')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--load_checkpoint_path',
|
||||
default=None,
|
||||
type=str,
|
||||
help='The path the the pretrained CaptionNet')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--res_path',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path where to log the predicted caption programs')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--gpu_ids',
|
||||
default='0',
|
||||
type=str,
|
||||
help='Id of the gpu to be used')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--seed',
|
||||
default=42,
|
||||
type=int,
|
||||
help='The seed used in training')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathTr',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed training data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathVal',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed validation data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathTest',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed test data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--vocabPath',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the generated vocabulary')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--batch_size',
|
||||
default=64,
|
||||
type=int,
|
||||
help='Batch size')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--num_workers',
|
||||
default=0,
|
||||
type=int,
|
||||
help='Number of workers for loading')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--num_iters',
|
||||
default=5000,
|
||||
type=int,
|
||||
help='Total number of iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--display_every',
|
||||
default=5,
|
||||
type=int,
|
||||
help='Display training information every N iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--debug_every',
|
||||
default=100,
|
||||
type=int,
|
||||
help='Display debug message every N iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--validate_every',
|
||||
default=1000,
|
||||
type=int,
|
||||
help='Validate every N iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--shuffle_data',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Activate to shuffle the training data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--optim',
|
||||
default='adam',
|
||||
type=str,
|
||||
help='The name of the optimizer to be used')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr',
|
||||
default=1e-3,
|
||||
type=float,
|
||||
help='Base learning rate')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--betas',
|
||||
default='0.9, 0.98',
|
||||
type=str,
|
||||
help='Adam optimizer\'s betas')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--eps',
|
||||
default='1e-9',
|
||||
type=float,
|
||||
help='Adam optimizer\'s epsilon')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr_decay_marks',
|
||||
default='50000, 55000',
|
||||
type=str,
|
||||
help='Learing rate decay marks')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr_decay_factor',
|
||||
default=0.5,
|
||||
type=float,
|
||||
help='Learning rate decay factor')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--weight_decay',
|
||||
default=1e-6,
|
||||
type=float,
|
||||
help='Weight decay')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--embedDim',
|
||||
default=300,
|
||||
type=int,
|
||||
help='Embedding dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--hiddenDim',
|
||||
default=512,
|
||||
type=int,
|
||||
help='LSTM hidden dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--numLayers',
|
||||
default=2,
|
||||
type=int,
|
||||
help='Number of hidden LSTM layers')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dropout',
|
||||
default=0.1,
|
||||
type=float,
|
||||
help='Dropout value')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--multiHead',
|
||||
default=8,
|
||||
type=int,
|
||||
help='Number of attention heads')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--hiddenSizeHead',
|
||||
default=64,
|
||||
type=int,
|
||||
help='Dimension of each attention head')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FeedForwardSize',
|
||||
default=2048,
|
||||
type=int,
|
||||
help='Dimension of the feed forward layer')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatMLPSize',
|
||||
default=512,
|
||||
type=int,
|
||||
help='MLP flatten size')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatGlimpses',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Number of flatten glimpses')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatOutSize',
|
||||
default=512,
|
||||
type=int,
|
||||
help='Final attention reduction dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--layers',
|
||||
default=6,
|
||||
type=int,
|
||||
help='Number of self attention layers')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--bidirectional',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Activate to use bidirectional LSTMs')
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def parse(self):
|
||||
# initialize parser
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
self.opts = self.parser.parse_args()
|
||||
|
||||
# parse gpu id list
|
||||
str_gpu_ids = self.opts.gpu_ids.split(',')
|
||||
self.opts.gpu_ids = []
|
||||
for str_id in str_gpu_ids:
|
||||
if str_id.isdigit() and int(str_id) >= 0:
|
||||
self.opts.gpu_ids.append(int(str_id))
|
||||
if len(self.opts.gpu_ids) > 0 and torch.cuda.is_available():
|
||||
print('\n[INFO] Using {} CUDA device(s) ...'.format(len(self.opts.gpu_ids)))
|
||||
else:
|
||||
print('\n[INFO] Using cpu ...')
|
||||
self.opts.gpu_ids = []
|
||||
|
||||
# parse the optimizer's betas and lr decay marks
|
||||
self.opts.betas = [float(beta) for beta in self.opts.betas.split(',')]
|
||||
lr_decay_marks = [int(m) for m in self.opts.lr_decay_marks.split(',')]
|
||||
for i in range(1, len(lr_decay_marks)):
|
||||
assert lr_decay_marks[i] > lr_decay_marks[i-1]
|
||||
self.opts.lr_decay_marks = lr_decay_marks
|
||||
|
||||
# print and save options
|
||||
args = vars(self.opts)
|
||||
print('\n ' + 30*'-' + 'Opts' + 30*'-')
|
||||
for k, v in args.items():
|
||||
print('%s: %s' % (str(k), str(v)))
|
||||
|
||||
if not os.path.isdir(self.opts.run_dir):
|
||||
os.makedirs(self.opts.run_dir)
|
||||
filename = 'opts.txt'
|
||||
file_path = os.path.join(self.opts.run_dir, filename)
|
||||
with open(file_path, 'wt') as fout:
|
||||
fout.write('| options\n')
|
||||
for k, v in sorted(args.items()):
|
||||
fout.write('%s: %s\n' % (str(k), str(v)))
|
||||
return self.opts
|
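Sketch of what parse() produces (illustration only, not part of the committed file); the flag values are examples, and note that parse() also prints all options and writes an opts.txt into run_dir as a side effect:

    import sys

    # hypothetical invocation; only the required flags plus two examples are shown
    sys.argv = ['train_caption_parser.py',
                '--mode', 'train', '--run_dir', 'runs/exp0',
                '--res_path', 'runs/exp0/res.txt',
                '--dataPathTr', 'tr.h5', '--dataPathVal', 'val.h5',
                '--dataPathTest', 'test.h5', '--vocabPath', 'vocab.json',
                '--gpu_ids', '0,1', '--betas', '0.9, 0.98']

    opts = Options().parse()
    # opts.gpu_ids        -> [0, 1]          (ints; emptied if CUDA is unavailable)
    # opts.betas          -> [0.9, 0.98]     (floats)
    # opts.lr_decay_marks -> [50000, 55000]  (ints, asserted to be increasing)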
prog_generator/options_question_parser.py (new file, 326 lines)
@@ -0,0 +1,326 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
# --------------------------------------------------------
|
||||
# adapted from https://github.com/kexinyi/ns-vqa/blob/master/scene_parse/attr_net/options.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import utils
|
||||
import torch
|
||||
|
||||
|
||||
class Options():
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self):
|
||||
self.parser.add_argument(
|
||||
'--mode',
|
||||
required=True,
|
||||
type=str,
|
||||
choices=['train', 'test_with_gt', 'test_with_pred'],
|
||||
help='The mode of the experiment')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--run_dir',
|
||||
required=True,
|
||||
type=str,
|
||||
help='The experiment directory')
|
||||
|
||||
# self.parser.add_argument('--dataset', default='clevr', type=str, help='dataset')
|
||||
self.parser.add_argument(
|
||||
'--text_log_dir',
|
||||
required=True,
|
||||
type=str,
|
||||
help='File to save the logged text')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--questionNetPath',
|
||||
default='',
|
||||
type=str,
|
||||
help='Path to the pretrained QuestionNet that will be used for testing.')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--captionNetPath',
|
||||
default='',
|
||||
type=str,
|
||||
help='Path to the pretrained CaptionNet that will be used for testing.')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dialogLen',
|
||||
default=10,
|
||||
type=int,
|
||||
help='Length of the dialogs to be used for testing. We used 10, 15, and 20 in our experiments.')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--last_n_rounds',
|
||||
default=10,
|
||||
type=int,
|
||||
help='Number of the last rounds to consider in the history. We used 1, 2, 3, 4, and 10 in our experiments. ')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--encoderType',
|
||||
required=True,
|
||||
type=int,
|
||||
choices=[1, 2],
|
||||
help='Type of the encoder: 1 --> Concat, 2 --> Stack')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--load_checkpoint_path',
|
||||
default='None',
|
||||
type=str,
|
||||
help='Path to a QestionNet checkpoint path to resume training')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--gpu_ids',
|
||||
default='0',
|
||||
type=str,
|
||||
help='Id of the gpu to be used')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--seed',
|
||||
default=42,
|
||||
type=int,
|
||||
help='The seed used in training')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathTr',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed training data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathVal',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed validation data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dataPathTest',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the h5 file of the Clevr-Dialog preprocessed test data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--scenesPath',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the derendered clevr-dialog scenes')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--vocabPath',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Path to the generated vocabulary')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--batch_size',
|
||||
default=64,
|
||||
type=int,
|
||||
help='Batch size')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--countFirstFailueRound',
|
||||
default=0,
|
||||
type=int,
|
||||
help='If activated, we count the first failure round')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--maxSamples',
|
||||
default=-1,
|
||||
type=int,
|
||||
help='Maximum number of training samples')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--num_workers',
|
||||
default=0,
|
||||
type=int,
|
||||
help='Number of workers for loading')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--num_iters',
|
||||
default=5000,
|
||||
type=int,
|
||||
help='Total number of iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--display_every',
|
||||
default=5,
|
||||
type=int,
|
||||
help='Display training information every N iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--validate_every',
|
||||
default=1000,
|
||||
type=int,
|
||||
help='Validate every N iterations')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--shuffle_data',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Activate to shuffle the training data')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--optim',
|
||||
default='adam',
|
||||
type=str,
|
||||
help='The name of the optimizer to be used')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr',
|
||||
default=1e-3,
|
||||
type=float,
|
||||
help='Base learning rate')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--betas',
|
||||
default='0.9, 0.98',
|
||||
type=str,
|
||||
help='Adam optimizer\'s betas')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--eps',
|
||||
default='1e-9',
|
||||
type=float,
|
||||
help='Adam optimizer\'s epsilon')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr_decay_marks',
|
||||
default='50000, 55000',
|
||||
type=str,
|
||||
help='Learing rate decay marks')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--lr_decay_factor',
|
||||
default=0.5,
|
||||
type=float,
|
||||
help='Learning rate decay factor')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--weight_decay',
|
||||
default=1e-6,
|
||||
type=float,
|
||||
help='Weight decay')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--embedDim',
|
||||
default=300,
|
||||
type=int,
|
||||
help='Embedding dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--hiddenDim',
|
||||
default=512,
|
||||
type=int,
|
||||
help='LSTM hidden dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--numLayers',
|
||||
default=2,
|
||||
type=int,
|
||||
help='Number of hidden LSTM layers')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--dropout',
|
||||
default=0.1,
|
||||
type=float,
|
||||
help='Dropout value')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--multiHead',
|
||||
default=8,
|
||||
type=int,
|
||||
help='Number of attention heads')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--hiddenSizeHead',
|
||||
default=64,
|
||||
type=int,
|
||||
help='Dimension of each attention head')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FeedForwardSize',
|
||||
default=2048,
|
||||
type=int,
|
||||
help='Dimension of the feed forward layer')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatMLPSize',
|
||||
default=512,
|
||||
type=int,
|
||||
help='MLP flatten size')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatGlimpses',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Number of flatten glimpses')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--FlatOutSize',
|
||||
default=512,
|
||||
type=int,
|
||||
help='Final attention reduction dimension')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--layers',
|
||||
default=6,
|
||||
type=int,
|
||||
help='Number of self attention layers')
|
||||
|
||||
self.parser.add_argument(
|
||||
'--bidirectional',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Activate to use bidirectional LSTMs')
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def parse(self):
|
||||
# initialize parser
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
self.opts = self.parser.parse_args()
|
||||
|
||||
# parse gpu id list
|
||||
str_gpu_ids = self.opts.gpu_ids.split(',')
|
||||
self.opts.gpu_ids = []
|
||||
for str_id in str_gpu_ids:
|
||||
if str_id.isdigit() and int(str_id) >= 0:
|
||||
self.opts.gpu_ids.append(int(str_id))
|
||||
if len(self.opts.gpu_ids) > 0 and torch.cuda.is_available():
|
||||
print('\n[INFO] Using {} CUDA device(s) ...'.format(
|
||||
len(self.opts.gpu_ids)))
|
||||
else:
|
||||
print('\n[INFO] Using cpu ...')
|
||||
self.opts.gpu_ids = []
|
||||
|
||||
# parse the optimizer's betas and lr decay marks
|
||||
self.opts.betas = [float(beta) for beta in self.opts.betas.split(',')]
|
||||
lr_decay_marks = [int(m) for m in self.opts.lr_decay_marks.split(',')]
|
||||
for i in range(1, len(lr_decay_marks)):
|
||||
assert lr_decay_marks[i] > lr_decay_marks[i-1]
|
||||
self.opts.lr_decay_marks = lr_decay_marks
|
||||
|
||||
# print and save options
|
||||
args = vars(self.opts)
|
||||
print('\n ' + 30*'-' + 'Opts' + 30*'-')
|
||||
for k, v in args.items():
|
||||
print('%s: %s' % (str(k), str(v)))
|
||||
|
||||
if not os.path.isdir(self.opts.run_dir):
|
||||
os.makedirs(self.opts.run_dir)
|
||||
filename = 'opts.txt'
|
||||
file_path = os.path.join(self.opts.run_dir, filename)
|
||||
with open(file_path, 'wt') as fout:
|
||||
fout.write('| options\n')
|
||||
for k, v in sorted(args.items()):
|
||||
fout.write('%s: %s\n' % (str(k), str(v)))
|
||||
return self.opts
|
prog_generator/train_caption_parser.py (new file, 280 lines)
@@ -0,0 +1,280 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
from clevrDialog_dataset import ClevrDialogCaptionDataset
|
||||
from models import SeqToSeqC, CaptionEncoder, Decoder
|
||||
from optim import get_optim, adjust_lr
|
||||
from options_caption_parser import Options
|
||||
import os, json, torch, pickle, copy, time
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as Data
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
class Execution:
|
||||
def __init__(self, opts):
|
||||
self.opts = opts
|
||||
|
||||
self.loss_fn = torch.nn.NLLLoss().cuda()
|
||||
|
||||
print("[INFO] Loading dataset ...")
|
||||
|
||||
self.dataset_tr = ClevrDialogCaptionDataset(
|
||||
opts.dataPathTr, opts.vocabPath, "train", "Captions Tr")
|
||||
|
||||
self.dataset_val = ClevrDialogCaptionDataset(
|
||||
opts.dataPathVal, opts.vocabPath, "val", "Captions Val")
|
||||
|
||||
self.dataset_test = ClevrDialogCaptionDataset(
|
||||
opts.dataPathTest, opts.vocabPath, "test", "Captions Test")
|
||||
|
||||
tb_path = os.path.join(opts.run_dir, "tb_logdir")
|
||||
if not os.path.isdir(tb_path):
|
||||
os.makedirs(tb_path)
|
||||
|
||||
self.ckpt_path = os.path.join(opts.run_dir, "ckpt_dir")
|
||||
if not os.path.isdir(self.ckpt_path):
|
||||
os.makedirs(self.ckpt_path)
|
||||
|
||||
self.writer = SummaryWriter(tb_path)
|
||||
self.iter_val = 0
|
||||
self.bestValAcc = float("-inf")
|
||||
self.bestValIter = -1
|
||||
|
||||
def constructNet(self, lenVocabText, lenVocabProg, maxLenProg, ):
|
||||
decoder = Decoder(self.opts, lenVocabProg, maxLenProg)
|
||||
encoder = CaptionEncoder(self.opts, lenVocabText)
|
||||
net = SeqToSeqC(encoder, decoder)
|
||||
return net
|
||||
|
||||
def train(self, dataset, dataset_val=None):
|
||||
# Obtain needed information
|
||||
lenVocabText = dataset.lenVocabText
|
||||
lenVocabProg = dataset.lenVocabProg
|
||||
maxLenProg = dataset.maxLenProg
|
||||
net = self.constructNet(lenVocabText, lenVocabProg, maxLenProg)
|
||||
|
||||
net.cuda()
|
||||
net.train()
|
||||
|
||||
# Define the multi-gpu training if needed
|
||||
if len(self.opts.gpu_ids) > 1:
|
||||
net = nn.DataParallel(net, device_ids=self.opts.gpu_ids)
|
||||
|
||||
# Load checkpoint if resume training
|
||||
if self.opts.load_checkpoint_path is not None:
|
||||
print("[INFO] Resume trainig from ckpt {} ...".format(
|
||||
self.opts.load_checkpoint_path
|
||||
))
|
||||
|
||||
# Load the network parameters
|
||||
ckpt = torch.load(self.opts.load_checkpoint_path)
|
||||
print("[INFO] Checkpoint successfully loaded ...")
|
||||
net.load_state_dict(ckpt['state_dict'])
|
||||
|
||||
# Load the optimizer paramters
|
||||
optim = get_optim(self.opts, net, len(dataset), lr_base=ckpt['lr_base'])
|
||||
optim.optimizer.load_state_dict(ckpt['optimizer'])
|
||||
|
||||
else:
|
||||
optim = get_optim(self.opts, net, len(dataset))
|
||||
_iter = 0
|
||||
epoch = 0
|
||||
|
||||
# Define dataloader
|
||||
dataloader = Data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.opts.batch_size,
|
||||
shuffle=self.opts.shuffle_data,
|
||||
num_workers=self.opts.num_workers,
|
||||
)
|
||||
_iterCur = 0
|
||||
_totalCur = len(dataloader)
|
||||
# Training loop
|
||||
while _iter < self.opts.num_iters:
|
||||
# Learning Rate Decay
|
||||
if _iter in self.opts.lr_decay_marks:
|
||||
adjust_lr(optim, self.opts.lr_decay_factor)
|
||||
|
||||
time_start = time.time()
|
||||
# Iteration
|
||||
for caption, captionPrg in dataloader:
|
||||
if _iter >= self.opts.num_iters:
|
||||
break
|
||||
caption = caption.cuda()
|
||||
captionPrg = captionPrg.cuda()
|
||||
captionPrgTarget = captionPrg.clone()
|
||||
optim.zero_grad()
|
||||
|
||||
predSoftmax, _ = net(caption, captionPrg)
|
||||
|
||||
loss = self.loss_fn(
|
||||
predSoftmax[:, :-1, :].contiguous().view(-1, predSoftmax.size(2)),
|
||||
captionPrgTarget[:, 1:].contiguous().view(-1))
|
||||
loss.backward()
|
||||
|
||||
# logging
|
||||
self.writer.add_scalar(
|
||||
'train/loss',
|
||||
loss.cpu().data.numpy(),
|
||||
global_step=_iter)
|
||||
|
||||
self.writer.add_scalar(
|
||||
'train/lr',
|
||||
optim._rate,
|
||||
global_step=_iter)
|
||||
if _iter % self.opts.display_every == 0:
|
||||
print("\r[CLEVR-Dialog - %s (%d/%4d)][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" % (
|
||||
dataset.name,
|
||||
_iterCur,
|
||||
_totalCur,
|
||||
epoch,
|
||||
_iter,
|
||||
self.opts.num_iters,
|
||||
loss.cpu().data.numpy(),
|
||||
optim._rate,
|
||||
), end=' ')
|
||||
optim.step()
|
||||
_iter += 1
|
||||
_iterCur += 1
|
||||
|
||||
if _iter % self.opts.validate_every == 0:
|
||||
if dataset_val is not None:
|
||||
valAcc = self.eval(
|
||||
net,
|
||||
dataset_val,
|
||||
valid=True,
|
||||
)
|
||||
if valAcc > self.bestValAcc:
|
||||
self.bestValAcc = valAcc
|
||||
self.bestValIter = _iter
|
||||
|
||||
print("[INFO] Checkpointing model @ iter {}".format(_iter))
|
||||
state = {
|
||||
'state_dict': net.state_dict(),
|
||||
'optimizer': optim.optimizer.state_dict(),
|
||||
'lr_base': optim.lr_base,
|
||||
'optim': optim.lr_base,
|
||||
'last_iter': _iter,
|
||||
'last_epoch': epoch,
|
||||
}
|
||||
# checkpointing
|
||||
torch.save(
|
||||
state,
|
||||
os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
|
||||
)
|
||||
else:
|
||||
print("[INFO] No validation dataset available")
|
||||
|
||||
time_end = time.time()
|
||||
print('Finished epoch in {}s'.format(int(time_end-time_start)))
|
||||
epoch += 1
|
||||
|
||||
print("[INFO] Training done. Best model had val acc. {} @ iter {}...".format(self.bestValAcc, self.bestValIter))
|
||||
|
||||
# Evaluation
|
||||
def eval(self, net, dataset, valid=False):
|
||||
net = net.eval()
|
||||
data_size = len(dataset)
|
||||
dataloader = Data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.opts.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.opts.num_workers,
|
||||
pin_memory=False
|
||||
)
|
||||
allPredictedProgs = []
|
||||
numAllProg = 0
|
||||
falsePred = 0
|
||||
for step, (caption, captionPrg) in enumerate(dataloader):
|
||||
print("\rEvaluation: [step %4d/%4d]" % (
|
||||
step,
|
||||
int(data_size / self.opts.batch_size),
|
||||
), end=' ')
|
||||
caption = caption.cuda()
|
||||
captionPrg = captionPrg.cuda()
|
||||
tokens = net.sample(caption)
|
||||
targetProgs = decodeProg(captionPrg, dataset.vocab["idx_prog_to_token"], target=True)
|
||||
predProgs = decodeProg(tokens, dataset.vocab["idx_prog_to_token"])
|
||||
allPredictedProgs.extend(list(map(lambda s: "( {} ( {} ) ) \n".format(s[0], ", ".join(s[1:])), predProgs)))
|
||||
numAllProg += len(targetProgs)
|
||||
for targetProg, predProg in zip(targetProgs, predProgs):
|
||||
mainMod = targetProg[0] == predProg[0]
|
||||
sameLength = len(targetProg) == len(predProg)
|
||||
sameArgs = False
|
||||
if sameLength:
|
||||
sameArgs = True
|
||||
for argTarget in targetProg[1:]:
|
||||
if argTarget not in predProg[1:]:
|
||||
sameArgs = False
|
||||
break
|
||||
|
||||
if not (mainMod and sameArgs):
|
||||
falsePred += 1
|
||||
val_acc = (1 - (falsePred / numAllProg)) * 100.0
|
||||
print("Acc: {}".format(val_acc))
|
||||
net = net.train()
|
||||
if not valid:
|
||||
with open(self.opts.res_path, "w") as f:
|
||||
f.writelines(allPredictedProgs)
|
||||
print("[INFO] Predicted caption programs logged into {}".format(self.opts.res_path))
|
||||
return val_acc
|
||||
|
||||
def run(self, run_mode):
|
||||
self.set_seed(self.opts.seed)
|
||||
if run_mode == 'train':
|
||||
self.train(self.dataset_tr, self.dataset_val)
|
||||
|
||||
elif run_mode == 'test':
|
||||
lenVocabText = self.dataset_test.lenVocabText
|
||||
lenVocabProg = self.dataset_test.lenVocabProg
|
||||
maxLenProg = self.dataset_test.maxLenProg
|
||||
net = self.constructNet(lenVocabText, lenVocabProg, maxLenProg)
|
||||
|
||||
print('Loading ckpt {}'.format(self.opts.load_checkpoint_path))
|
||||
state_dict = torch.load(self.opts.load_checkpoint_path)['state_dict']
|
||||
net.load_state_dict(state_dict)
|
||||
net.cuda()
|
||||
self.eval(net, self.dataset_test)
|
||||
|
||||
else:
|
||||
exit(-1)
|
||||
|
||||
def set_seed(self, seed):
|
||||
"""Sets the seed for reproducibility.
|
||||
Args:
|
||||
seed (int): The seed used
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
np.random.seed(seed)
|
||||
print('[INFO] Seed set to {}...'.format(seed))
|
||||
|
||||
|
||||
def decodeProg(tokens, prgIdxToToken, target=False):
|
||||
tokensBatch = tokens.tolist()
|
||||
progsBatch = []
|
||||
for tokens in tokensBatch:
|
||||
prog = []
|
||||
for tok in tokens:
|
||||
if tok == 2: # <END> has index 2
|
||||
break
|
||||
prog.append(prgIdxToToken.get(tok))
|
||||
if target:
|
||||
prog = prog[1:]
|
||||
progsBatch.append(prog)
|
||||
return progsBatch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opts = Options().parse()
|
||||
exe = Execution(opts)
|
||||
exe.run(opts.mode)
|
||||
print("[INFO] Done ...")
|
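A small worked example of decodeProg (illustration only, not part of the committed file), with a made-up index-to-token map; token id 2 is <END>, and target programs additionally drop their leading <START> token:

    import torch

    prgIdxToToken = {1: "<START>", 2: "<END>", 3: "extreme-right", 4: "large", 5: "red"}
    tokens = torch.tensor([[1, 3, 4, 5, 2],     # <START> extreme-right large red <END>
                           [1, 3, 5, 2, 0]])

    print(decodeProg(tokens, prgIdxToToken, target=True))
    # [['extreme-right', 'large', 'red'], ['extreme-right', 'red']]
    print(decodeProg(tokens, prgIdxToToken))
    # [['<START>', 'extreme-right', 'large', 'red'], ['<START>', 'extreme-right', 'red']]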
prog_generator/train_question_parser.py (new file, 912 lines; listing truncated below)
@@ -0,0 +1,912 @@
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json, torch, pickle, copy, time
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as Data
|
||||
from tensorboardX import SummaryWriter
|
||||
from copy import deepcopy
|
||||
from clevrDialog_dataset import ClevrDialogQuestionDataset
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from executor.symbolic_executor import SymbolicExecutorClevr, SymbolicExecutorMinecraft
|
||||
from models import SeqToSeqQ, QuestEncoder_1, QuestEncoder_2, Decoder, CaptionEncoder, SeqToSeqC
|
||||
from optim import get_optim, adjust_lr
|
||||
from options_caption_parser import Options as OptionsC
|
||||
from options_question_parser import Options as OptionsQ
|
||||
|
||||
|
||||
class Execution:
|
||||
def __init__(self, optsQ, optsC):
|
||||
self.opts = deepcopy(optsQ)
|
||||
if self.opts.useCuda > 0 and torch.cuda.is_available():
|
||||
self.device = torch.device("cuda:0")
|
||||
print("[INFO] Using GPU {} ...".format(torch.cuda.get_device_name(0)))
|
||||
else:
|
||||
print("[INFO] Using CPU ...")
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.loss_fn = torch.nn.NLLLoss().to(self.device)
|
||||
|
||||
print("[INFO] Loading dataset ...")
|
||||
|
||||
self.datasetTr = ClevrDialogQuestionDataset(
|
||||
self.opts.dataPathTr, self.opts.vocabPath, "train", "All tr data")
|
||||
|
||||
self.datasetVal = ClevrDialogQuestionDataset(
|
||||
self.opts.dataPathVal, self.opts.vocabPath, "val", "All val data", train=False)
|
||||
|
||||
self.datasetTest = ClevrDialogQuestionDataset(
|
||||
self.opts.dataPathTest, self.opts.vocabPath, "test", "All val data", train=False)
|
||||
|
||||
self.QuestionNet = constructQuestionNet(
|
||||
self.opts,
|
||||
self.datasetTr.lenVocabText,
|
||||
self.datasetTr.lenVocabProg,
|
||||
self.datasetTr.maxLenProg,
|
||||
)
|
||||
|
||||
if os.path.isfile(self.opts.captionNetPath):
|
||||
self.CaptionNet = constructCaptionNet(
|
||||
optsC,
|
||||
self.datasetTr.lenVocabText,
|
||||
self.datasetTr.lenVocabProg,
|
||||
self.datasetTr.maxLenProg
|
||||
)
|
||||
print('Loading CaptionNet from {}'.format(self.opts.captionNetPath))
|
||||
state_dict = torch.load(self.opts.captionNetPath)['state_dict']
|
||||
self.CaptionNet.load_state_dict(state_dict)
|
||||
self.CaptionNet.to(self.device)
|
||||
total_params_cap = sum(p.numel() for p in self.CaptionNet.parameters() if p.requires_grad)
|
||||
print("The caption encoder has {} trainable parameters".format(total_params_cap))
|
||||
|
||||
self.QuestionNet.to(self.device)
|
||||
# if os.path.isfile(self.opts.load_checkpoint_path):
|
||||
# print('Loading QuestionNet from {}'.format(optsQ.load_checkpoint_path))
|
||||
# state_dict = torch.load(self.opts.load_checkpoint_path)['state_dict']
|
||||
# self.QuestionNet.load_state_dict(state_dict)
|
||||
total_params_quest = sum(p.numel() for p in self.QuestionNet.parameters() if p.requires_grad)
|
||||
print("The question encoder has {} trainable parameters".format(total_params_quest))
|
||||
|
||||
if "minecraft" in self.opts.scenesPath:
|
||||
self.symbolicExecutor = SymbolicExecutorMinecraft(self.opts.scenesPath)
|
||||
else:
|
||||
self.symbolicExecutor = SymbolicExecutorClevr(self.opts.scenesPath)
|
||||
|
||||
tb_path = os.path.join(self.opts.run_dir, "tb_logdir")
|
||||
if not os.path.isdir(tb_path):
|
||||
os.makedirs(tb_path)
|
||||
|
||||
self.ckpt_path = os.path.join(self.opts.run_dir, "ckpt_dir")
|
||||
if not os.path.isdir(self.ckpt_path):
|
||||
os.makedirs(self.ckpt_path)
|
||||
if not os.path.isdir(self.opts.text_log_dir):
|
||||
os.makedirs(self.opts.text_log_dir)
|
||||
|
||||
self.writer = SummaryWriter(tb_path)
|
||||
self.iter_val = 0
|
||||
|
||||
if os.path.isfile(self.opts.dependenciesPath):
|
||||
with open(self.opts.dependenciesPath, "rb") as f:
|
||||
self.dependencies = pickle.load(f)
|
||||
|
||||
def train(self):
|
||||
self.QuestionNet.train()
|
||||
|
||||
# Define the multi-gpu training if needed
|
||||
if len(self.opts.gpu_ids) > 1:
|
||||
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||
|
||||
# Load checkpoint if resume training
|
||||
if os.path.isfile(self.opts.load_checkpoint_path):
|
||||
print("[INFO] Resume trainig from ckpt {} ...".format(
|
||||
self.opts.load_checkpoint_path
|
||||
))
|
||||
|
||||
# Load the network parameters
|
||||
ckpt = torch.load(self.opts.load_checkpoint_path)
|
||||
print("[INFO] Checkpoint successfully loaded ...")
|
||||
self.QuestionNet.load_state_dict(ckpt['state_dict'])
|
||||
|
||||
# Load the optimizer paramters
|
||||
optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr)) # , ckpt['optim'], lr_base=ckpt['lr_base'])
|
||||
# optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH)
|
||||
optim.optimizer.load_state_dict(ckpt['optimizer'])
|
||||
_iter = 0 # ckpt['last_iter']
|
||||
epoch = 0 # ckpt['last_epoch']
|
||||
|
||||
else:
|
||||
optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))
|
||||
_iter = 0
|
||||
epoch = 0
|
||||
|
||||
trainTime = 0
|
||||
bestValAcc = float("-inf")
|
||||
bestCkp = 0
|
||||
# Training loop
|
||||
while _iter < self.opts.num_iters:
|
||||
|
||||
# Learning Rate Decay
|
||||
if _iter in self.opts.lr_decay_marks:
|
||||
adjust_lr(optim, self.opts.lr_decay_factor)
|
||||
|
||||
# Define multi-thread dataloader
|
||||
dataloader = Data.DataLoader(
|
||||
self.datasetTr,
|
||||
batch_size=self.opts.batch_size,
|
||||
shuffle=self.opts.shuffle_data,
|
||||
num_workers=self.opts.num_workers,
|
||||
)
|
||||
|
||||
# Iteration
|
||||
time_start = 0
|
||||
time_end = 0
|
||||
for batch_iter, (quest, hist, prog, questionRound, _) in enumerate(dataloader):
|
||||
time_start = time.time()
|
||||
if _iter >= self.opts.num_iters:
|
||||
break
|
||||
quest = quest.to(self.device)
|
||||
if self.opts.last_n_rounds < 10:
|
||||
last_n_rounds_batch = []
|
||||
for i, r in enumerate(questionRound.tolist()):
|
||||
startIdx = max(r - self.opts.last_n_rounds, 0)
|
||||
endIdx = max(r, self.opts.last_n_rounds)
|
||||
if hist.dim() == 3:
|
||||
assert endIdx - startIdx == self.opts.last_n_rounds
|
||||
histBatch = hist[i, :, :]
|
||||
last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
|
||||
elif hist.dim() == 2:
|
||||
startIdx *= 20
|
||||
endIdx *= 20
|
||||
histBatch = hist[i, :]
|
||||
temp = histBatch[startIdx:endIdx].cpu()
|
||||
if r > self.opts.last_n_rounds:
|
||||
last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
|
||||
else:
|
||||
last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
|
||||
hist = torch.stack(last_n_rounds_batch, dim=0)
|
||||
hist = hist.to(self.device)
|
||||
prog = prog.to(self.device)
|
||||
progTarget = prog.clone()
|
||||
optim.zero_grad()
|
||||
|
||||
predSoftmax, _ = self.QuestionNet(quest, hist, prog[:, :-1])
|
||||
loss = self.loss_fn(
|
||||
# predSoftmax[:, :-1, :].contiguous().view(-1, predSoftmax.size(2)),
|
||||
predSoftmax.contiguous().view(-1, predSoftmax.size(2)),
|
||||
progTarget[:, 1:].contiguous().view(-1))
|
||||
loss.backward()
|
||||
|
||||
if _iter % self.opts.validate_every == 0 and _iter > 0:
|
||||
valAcc = self.val()
|
||||
if valAcc > bestValAcc:
|
||||
bestValAcc = valAcc
|
||||
bestCkp = _iter
|
||||
print("\n[INFO] Checkpointing model @ iter {} with val accuracy {}\n".format(_iter, valAcc))
|
||||
state = {
|
||||
'state_dict': self.QuestionNet.state_dict(),
|
||||
'optimizer': optim.optimizer.state_dict(),
|
||||
'lr_base': optim.lr_base,
|
||||
'optim': optim.lr_base,
|
||||
'last_iter': _iter,
|
||||
'last_epoch': epoch,
|
||||
}
|
||||
# checkpointing
|
||||
torch.save(
|
||||
state,
|
||||
os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
|
||||
)
|
||||
|
||||
# logging
|
||||
self.writer.add_scalar(
|
||||
'train/loss',
|
||||
loss.cpu().data.numpy(),
|
||||
global_step=_iter)
|
||||
|
||||
self.writer.add_scalar(
|
||||
'train/lr',
|
||||
optim._rate,
|
||||
global_step=_iter)
|
||||
if _iter % self.opts.display_every == 0:
|
||||
time_end = time.time()
|
||||
trainTime += time_end-time_start
|
||||
|
||||
print("\r[CLEVR-Dialog - %s (%d | %d)][epoch %2d][iter %4d/%4d][runtime %4f] loss: %.4f, lr: %.2e" % (
|
||||
self.datasetTr.name,
|
||||
batch_iter,
|
||||
len(dataloader),
|
||||
epoch,
|
||||
_iter,
|
||||
self.opts.num_iters,
|
||||
trainTime,
|
||||
loss.cpu().data.numpy(),
|
||||
optim._rate,
|
||||
), end=' ')
|
||||
|
||||
optim.step()
|
||||
_iter += 1
|
||||
|
||||
epoch += 1
|
||||
print("[INFO] Avg. epoch time: {} s".format(trainTime / epoch))
|
||||
print("[INFO] Best model achieved val acc. {} @ iter {}".format(bestValAcc, bestCkp))
|
||||
|
||||
def val(self):
|
||||
self.QuestionNet.eval()
|
||||
|
||||
total_correct = 0
|
||||
total = 0
|
||||
|
||||
if len(self.opts.gpu_ids) > 1:
|
||||
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||
self.QuestionNet = self.QuestionNet.eval()
|
||||
dataloader = Data.DataLoader(
|
||||
self.datasetVal,
|
||||
batch_size=self.opts.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.opts.num_workers,
|
||||
pin_memory=False
|
||||
)
|
||||
_iterCur = 0
|
||||
_totalCur = len(dataloader)

        for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
            print("\rEvaluation: [step %4d/%4d]" % (
                step,
                int(len(dataloader)),
            ), end=' ')

            question = question.to(self.device)

            if history.dim() == 3:
                caption = history.detach()
                caption = caption[:, 0, :]
                caption = caption[:, :16].to(self.device)
            elif history.dim() == 2:
                caption = history.detach()
                caption = caption[:, :16].to(self.device)
            if self.opts.last_n_rounds is not None:
                last_n_rounds_batch = []
                for i, r in enumerate(questionRounds.tolist()):
                    startIdx = max(r - self.opts.last_n_rounds, 0)
                    endIdx = max(r, self.opts.last_n_rounds)
                    if history.dim() == 3:
                        assert endIdx - startIdx == self.opts.last_n_rounds
                        histBatch = history[i, :, :]
                        last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
                    elif history.dim() == 2:
                        startIdx *= 20
                        endIdx *= 20
                        histBatch = history[i, :]
                        temp = histBatch[startIdx:endIdx]
                        if r > self.opts.last_n_rounds:
                            last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
                        else:
                            last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
                history = torch.stack(last_n_rounds_batch, dim=0)
            history = history.to(self.device)
            questionPrg = questionPrg.to(self.device)

            questProgsToksPred = self.QuestionNet.sample(question, history)
            questProgsPred = decodeProg(questProgsToksPred, self.datasetVal.vocab["idx_prog_to_token"])
            targetProgs = decodeProg(questionPrg, self.datasetVal.vocab["idx_prog_to_token"], target=True)

            # A program is counted as correct only on an exact sequence match.
            correct = [1 if pred == gt else 0 for (pred, gt) in zip(questProgsPred, targetProgs)]

            correct = sum(correct)
            total_correct += correct
            total += len(targetProgs)
        self.QuestionNet.train()

        return 100.0 * (total_correct / total)
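    # val() scores exact sequence match: a predicted program only counts if every
    # decoded token equals the ground truth, which makes this stricter than any
    # per-token accuracy.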

    # Evaluation
    def eval_with_gt(self):
        # Define the multi-gpu setting if needed
        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []
        all_pred_programs = []
        all_gt_programs = []

        first_failure_round = 0
        total_correct = 0
        total_acc_pen = 0
        total = 0
        total_quest_prog_correct = 0

        if len(self.opts.gpu_ids) > 1:
            self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
        self.QuestionNet = self.QuestionNet.eval()
        self.CaptionNet = self.CaptionNet.eval()
        # One batch per dialog: the batch size must equal the number of rounds so
        # that the first-failure-round statistic below spans a single dialog.
        if self.opts.batch_size != self.opts.dialogLen:
            print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
            self.opts.batch_size = self.opts.dialogLen
        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        _iterCur = 0
        _totalCur = len(dataloader)

        for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
            batchSize = question.size(0)
            question = question.to(self.device)

            if history.dim() == 3:
                caption = history.detach()
                caption = caption[:, 0, :]
                caption = caption[:, :16].to(self.device)
            elif history.dim() == 2:
                caption = history.detach()
                caption = caption[:, :16].to(self.device)
            if self.opts.last_n_rounds < 10:
                last_n_rounds_batch = []
                for i, r in enumerate(questionRounds.tolist()):
                    startIdx = max(r - self.opts.last_n_rounds, 0)
                    endIdx = max(r, self.opts.last_n_rounds)
                    if history.dim() == 3:
                        assert endIdx - startIdx == self.opts.last_n_rounds
                        histBatch = history[i, :, :]
                        last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
                    elif history.dim() == 2:
                        startIdx *= 20
                        endIdx *= 20
                        histBatch = history[i, :]
                        temp = histBatch[startIdx:endIdx]
                        if r > self.opts.last_n_rounds:
                            last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
                        else:
                            last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
                history = torch.stack(last_n_rounds_batch, dim=0)

            history = history.to(self.device)
            questionPrg = questionPrg.to(self.device)
            historiesProg = historiesProg.tolist()
            questionRounds = questionRounds.tolist()
            answer = answer.tolist()
            answers = list(map(lambda a: self.datasetTest.vocab["idx_text_to_token"][a], answer))
            questionImgIdx = questionImgIdx.tolist()
            # if "minecraft" in self.opts.scenesPath:
            #     questionImgIdx = [idx - 1 for idx in questionImgIdx]
            questProgsToksPred = self.QuestionNet.sample(question, history)
            capProgsToksPred = self.CaptionNet.sample(caption)

            questProgsPred = decodeProg(questProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])
            capProgsPred = decodeProg(capProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])

            targetProgs = decodeProg(questionPrg, self.datasetTest.vocab["idx_prog_to_token"], target=True)
            questionTypes = [targetProg[0] for targetProg in targetProgs]
            progHistories = [getProgHistories(progHistToks, self.datasetTest.vocab["idx_prog_to_token"]) for progHistToks in historiesProg]
            pred_answers = []
            all_pred_programs.append([capProgsPred[0]] + questProgsPred)
            all_gt_programs.append([progHistories[0]] + targetProgs)

            for i in range(batchSize):
                ans = self.getPrediction(
                    questProgsPred[i],
                    capProgsPred[i],
                    progHistories[i],
                    questionImgIdx[i]
                )
                pred_answers.append(ans)
            correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
            correct_prog = [1 if pred == ans else 0 for (pred, ans) in zip(questProgsPred, targetProgs)]
            idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
            if idx_false.shape[-1] > 0:
                first_failure_round += idx_false[0] + 1
            else:
                first_failure_round += self.opts.dialogLen + 1

            correct = sum(correct)
            correct_prog = sum(correct_prog)
            total_correct += correct
            total_quest_prog_correct += correct_prog
            total += len(answers)
            all_pred_answers.append(pred_answers)
            all_gt_answers.append(answers)
            all_question_types.append(questionTypes)
            # No answer penalties are applied in this mode; use zeros
            # (a weight of 0.5**0 == 1 in the accuracy helpers below).
            penalty = np.zeros(len(answers), dtype=np.int64)
            all_penalties.append(penalty)
            _iterCur += 1
            if _iterCur % self.opts.display_every == 0:
                print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                    _iterCur, _totalCur, 100.0 * (total_correct / total)))

        ffr = 1.0 * (first_failure_round / _totalCur) / (self.opts.dialogLen + 1)

        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        # total_acc_pen is never incremented in this mode, so the VD accuracy is 0.
        vd_acc = total_acc_pen / total
        quest_prog_acc = total_quest_prog_correct / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)
        textOut += "\n --------------- Question Prog. Acc --------------- \n"
        textOut += "{}".format(100.0 * quest_prog_acc)
        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)

        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)

        textOut += "\n --------------- Done --------------- \n"
        print(textOut)
        fname = self.opts.questionNetPath.split("/")[-3] + "results_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen)
        pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
        # NOTE: the output directories below are machine-specific absolute paths.
        pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)
        model_name = "NSVD_stack" if "stack" in self.opts.questionNetPath else "NSVD_concat"
        experiment_name = "minecraft"
        prog_output_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/prog_output/{}_{}.pkl".format(model_name, experiment_name))

        fpath = os.path.join(self.opts.text_log_dir, fname)
        with open(fpath, "w") as f:
            f.writelines(textOut)
        with open(pred_answers_fname, "wb") as f:
            pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(prog_output_fname, "wb") as f:
            pickle.dump((all_gt_programs, all_pred_programs, all_pred_answers), f, protocol=pickle.HIGHEST_PROTOCOL)

    # Evaluation
    def eval_with_pred(self):
        # Define the multi-gpu setting if needed
        all_pred_answers = []
        all_gt_answers = []
        all_question_types = []
        all_penalties = []

        first_failure_round = 0
        total_correct = 0
        total_acc_pen = 0
        total = 0

        samples = {}

        if len(self.opts.gpu_ids) > 1:
            self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
        self.QuestionNet = self.QuestionNet.eval()
        self.CaptionNet = self.CaptionNet.eval()
        if self.opts.batch_size != self.opts.dialogLen:
            print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
            self.opts.batch_size = self.opts.dialogLen
        dataloader = Data.DataLoader(
            self.datasetTest,
            batch_size=self.opts.batch_size,
            shuffle=False,
            num_workers=self.opts.num_workers,
            pin_memory=False
        )
        _iterCur = 0
        _totalCur = len(dataloader)
        step = 0
        for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
            question = question.tolist()
            questions = decode(question, self.datasetTest.vocab["idx_text_to_token"], target=True)
            questions = list(map(lambda q: " ".join(q), questions))
            targetProgs = decode(questionPrg, self.datasetTest.vocab["idx_prog_to_token"], target=True)

            questionTypes = [targetProg[0] for targetProg in targetProgs]
            targetProgs = list(map(lambda q: " ".join(q), targetProgs))

            historiesProg = historiesProg.tolist()
            progHistories = [getProgHistories(progHistToks, self.datasetTest.vocab["idx_prog_to_token"]) for progHistToks in historiesProg]

            answer = answer.tolist()
            answers = list(map(lambda a: self.datasetTest.vocab["idx_text_to_token"][a], answer))
            questionImgIdx = questionImgIdx.tolist()

            # Seed the running history with the caption of the first round.
            if self.opts.encoderType == 2:
                # Stacked history encoder: one entry per round.
                histories_eval = [history[0, 0, :].tolist()]
                caption = history.detach()
                caption = caption[0, 0, :].unsqueeze(0)
                caption = caption[:, :16].to(self.device)
            elif self.opts.encoderType == 1:
                # Concatenated history encoder: one flat token sequence.
                caption = history.detach()
                histories_eval = [history[0, :20].tolist()]
                caption = caption[0, :16].unsqueeze(0).to(self.device)
            cap = decode(caption, self.datasetTest.vocab["idx_text_to_token"], target=False)
            capProgToksPred = self.CaptionNet.sample(caption)
            capProgPred = decode(capProgToksPred, self.datasetTest.vocab["idx_prog_to_token"])[0]

            pred_answers = []
            pred_quest_prog = []
            for i, (q, prog_hist, img_idx) in enumerate(zip(question, progHistories, questionImgIdx)):
                _round = i + 1
                # Window of the last n rounds of the running (predicted) history.
                if _round <= self.opts.last_n_rounds:
                    start = 0
                else:
                    start = _round - self.opts.last_n_rounds
                end = len(histories_eval)

                quest = torch.tensor(q).unsqueeze(0).to(self.device)
                # encoderType 2 = stacked history, 1 = concatenated history
                # (matching the branches above and constructQuestionNet below).
                if self.opts.encoderType == 2:
                    hist = torch.stack([torch.tensor(h) for h in histories_eval[start:end]], dim=0).unsqueeze(0).to(self.device)
                elif self.opts.encoderType == 1:
                    histories_eval_copy = deepcopy(histories_eval)
                    histories_eval_copy[-1].append(self.datasetTest.vocab["text_token_to_idx"]["<END>"])
                    hist = torch.cat([torch.tensor(h) for h in histories_eval_copy[start:end]], dim=-1).unsqueeze(0).to(self.device)

                questProgsToksPred = self.QuestionNet.sample(quest, hist)
                questProgsPred = decode(questProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])[0]
                pred_quest_prog.append(" ".join(questProgsPred))
                ans = self.getPrediction(
                    questProgsPred,
                    capProgPred,
                    prog_hist,
                    img_idx
                )
                ans_idx = self.datasetTest.vocab["text_token_to_idx"].get(
                    ans, self.datasetTest.vocab["text_token_to_idx"]["<UNK>"])
                # Splice the predicted answer into the question before appending it
                # to the running history: <END> -> <NULL>, answer before the final <END>.
                q[q.index(self.datasetTest.vocab["text_token_to_idx"]["<END>"])] = self.datasetTest.vocab["text_token_to_idx"]["<NULL>"]
                q[-1] = self.datasetTest.vocab["text_token_to_idx"]["<END>"]
                q.insert(-1, ans_idx)
                if self.opts.encoderType == 2:
                    histories_eval.append(copy.deepcopy(q))
                elif self.opts.encoderType == 1:
                    del q[0]
                    del q[-1]
                    histories_eval.append(copy.deepcopy(q))

                pred_answers.append(ans)

            correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
            idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
            if idx_false.shape[-1] > 0:
                first_failure_round += idx_false[0] + 1
            else:
                first_failure_round += self.opts.dialogLen + 1

            correct = sum(correct)
            total_correct += correct
            total += len(answers)
            all_pred_answers.append(pred_answers)
            all_gt_answers.append(answers)
            all_question_types.append(questionTypes)
            # Zero penalties, as in eval_with_gt, so the per-round/per-type helpers
            # receive one penalty entry per dialog.
            all_penalties.append(np.zeros(len(answers), dtype=np.int64))
            _iterCur += 1
            if _iterCur % self.opts.display_every == 0:
                print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
                    _iterCur, _totalCur, 100.0 * (total_correct / total)
                ))
            samples["{}_{}".format(questionImgIdx[0], (step % 5) + 1)] = {
                "caption": " ".join(cap[0]),
                "cap_prog_gt": " ".join(progHistories[0][0]),
                "cap_prog_pred": " ".join(capProgPred),

                "questions": questions,
                "quest_progs_gt": targetProgs,
                "quest_progs_pred": pred_quest_prog,

                "answers": answers,
                "preds": pred_answers,
                "acc": correct,
            }

        ffr = 1.0 * self.opts.dialogLen * (first_failure_round / total)

        textOut = "\n --------------- Average First Failure Round --------------- \n"
        textOut += "{} / {}".format(ffr, self.opts.dialogLen)

        accuracy = total_correct / total
        # total_acc_pen is never incremented in this mode, so the VD accuracy is 0.
        vd_acc = total_acc_pen / total
        textOut += "\n --------------- Overall acc. --------------- \n"
        textOut += "{}".format(100.0 * accuracy)
        textOut += "\n --------------- Overall VD acc. --------------- \n"
        textOut += "{}".format(100.0 * vd_acc)

        textOut += get_per_round_acc(
            all_pred_answers, all_gt_answers, all_penalties)

        textOut += get_per_question_type_acc(
            all_pred_answers, all_gt_answers, all_question_types, all_penalties)

        textOut += "\n --------------- Done --------------- \n"
        print(textOut)
        # Only write the logs once the whole test set has been consumed.
        if step == len(dataloader) - 1:
            fname = self.opts.questionNetPath.split("/")[-3] + "_results_{}_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen, self.acc_type)
            pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
            pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)

            fpath = os.path.join(self.opts.text_log_dir, fname)
            with open(fpath, "w") as f:
                f.writelines(textOut)
            with open(pred_answers_fname, "wb") as f:
                pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)

    def getPrediction(self, questProgPred, capProgPred, historyProg, imgIndex):
        self.symbolicExecutor.reset(imgIndex)
        # In round one, execute the predicted caption program first,
        # then answer the question.
        if len(historyProg) == 1:
            captionFuncLabel = capProgPred[0]
            captionFuncArgs = capProgPred[1:]

            questionFuncLabel = questProgPred[0]
            questionFuncArgs = questProgPred[1:]

            try:
                _ = self.symbolicExecutor.execute(captionFuncLabel, captionFuncArgs)
            except Exception:
                return "Error"

            try:
                predAnswer = self.symbolicExecutor.execute(questionFuncLabel, questionFuncArgs)
            except Exception:
                return "Error"

        # If it is not the first round, we have to execute the program history
        # and then answer the question.
        else:
            questionFuncLabel = questProgPred[0]
            questionFuncArgs = questProgPred[1:]
            for prg in historyProg:
                FuncLabel = prg[0]
                FuncArgs = prg[1:]
                try:
                    _ = self.symbolicExecutor.execute(FuncLabel, FuncArgs)
                except Exception:
                    return "Error"

            try:
                predAnswer = self.symbolicExecutor.execute(questionFuncLabel, questionFuncArgs)
            except Exception:
                return "Error"
        return str(predAnswer)
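    # A decoded program is a flat list [function, arg1, arg2, ...]; for example a
    # (hypothetical) ["filter-color", "red"] runs function "filter-color" with
    # argument list ["red"] on the symbolic executor.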

    def run(self, run_mode, epoch=None):
        self.set_seed(self.opts.seed)
        if run_mode == 'train':
            self.train()

        elif run_mode == 'test_with_gt':
            print('Testing with gt answers in history')
            print('Loading ckpt {}'.format(self.opts.questionNetPath))
            state_dict = torch.load(self.opts.questionNetPath)['state_dict']
            self.QuestionNet.load_state_dict(state_dict)
            self.eval_with_gt()

        elif run_mode == 'test_with_pred':
            print('Testing with predicted answers in history')
            print('Loading ckpt {}'.format(self.opts.questionNetPath))
            state_dict = torch.load(self.opts.questionNetPath)['state_dict']
            self.QuestionNet.load_state_dict(state_dict)
            self.eval_with_pred()
        else:
            print("[ERROR] Unknown run_mode '{}'".format(run_mode))
            exit(-1)

    def set_seed(self, seed):
        """Sets the seed for reproducibility.

        Args:
            seed (int): The seed to use.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        print('[INFO] Seed set to {}...'.format(seed))
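        # cudnn.deterministic trades speed for run-to-run reproducibility. For
        # multi-GPU runs, torch.cuda.manual_seed_all(seed) would also be needed;
        # only the current device is seeded here.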


def constructQuestionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    decoder = Decoder(opts, lenVocabProg, maxLenProg)
    if opts.encoderType == 1:
        encoder = QuestEncoder_1(opts, lenVocabText)
    elif opts.encoderType == 2:
        encoder = QuestEncoder_2(opts, lenVocabText)
    else:
        raise ValueError("Unknown encoderType: {}".format(opts.encoderType))

    net = SeqToSeqQ(encoder, decoder)
    return net
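# Note on encoder types, per the history handling above: encoderType 1
# (QuestEncoder_1) pairs with the flat concatenated history (history.dim() == 2),
# while encoderType 2 (QuestEncoder_2) pairs with the per-round stacked history
# (history.dim() == 3).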


def constructCaptionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
    decoder = Decoder(opts, lenVocabProg, maxLenProg)
    encoder = CaptionEncoder(opts, lenVocabText)
    net = SeqToSeqC(encoder, decoder)
    return net


def getProgHistories(progHistToks, prgIdxToToken):
    """Splits a flat token sequence of past programs into one list per round."""
    progHist = []
    temp = []
    for tok in progHistToks:
        # 0/1/2 are the special <NULL>/<START>/<END> ids and are never emitted.
        if tok not in [0, 1, 2]:
            temp.append(prgIdxToToken[tok])
        if tok == 2:
            # <END> closes the program of one round.
            progHist.append(temp)
            temp = []
    return progHist
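# Example with a hypothetical vocab v = {5: "filter", 9: "red"} (0/1/2 reserved):
#   getProgHistories([1, 5, 9, 2, 1, 5, 2], v) -> [["filter", "red"], ["filter"]]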


def getHistoriesFromStack(histToks, textIdxToToken):
    """Renders a stacked (one row per round) history as readable text."""
    histories = "\n"
    temp = []
    for i, roundToks in enumerate(histToks):
        for tok in roundToks:
            if tok not in [0, 1, 2]:
                temp.append(textIdxToToken[tok])
            if tok == 2:
                if i == 0:
                    # The first round is the caption.
                    histories += " ".join(temp) + ".\n"
                else:
                    # Later rounds are question-answer pairs; the answer is the last token.
                    histories += " ".join(temp[:-1]) + "? | {}.\n".format(temp[-1])
                temp = []
                break
    return histories


def getHistoriesFromConcat(histToks, textIdxToToken):
    """Renders a flat, concatenated history as a list of readable rounds."""
    histories = []
    temp = []
    for tok in histToks:
        if tok not in [0, 1, 2]:
            temp.append(textIdxToToken[tok])
        if tok == 2:
            histories.append(" ".join(temp[:-1]) + "? | {}".format(temp[-1]))
            temp = []
    return histories


def decodeProg(tokens, prgIdxToToken, target=False):
    """Decodes a batch of program token ids into lists of program tokens."""
    tokensBatch = tokens.tolist()
    progsBatch = []
    for toks in tokensBatch:
        prog = []
        for tok in toks:
            if tok == 2:  # <END> has index 2
                break
            prog.append(prgIdxToToken.get(tok))
        if target:
            # Drop the leading <START> token of ground-truth programs.
            prog = prog[1:]
        progsBatch.append(prog)
    return progsBatch
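# Example with a hypothetical vocab v = {1: "<START>", 7: "count-all"}:
#   decodeProg(torch.tensor([[1, 7, 2, 0]]), v)               -> [["<START>", "count-all"]]
#   decodeProg(torch.tensor([[1, 7, 2, 0]]), v, target=True)  -> [["count-all"]]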


def printPred(predSoftmax, gts, prgIdxToToken):
    """Formats greedy predictions next to their targets for logging."""
    assert predSoftmax.size(0) == gts.size(0)
    tokens = predSoftmax.topk(1)[1].squeeze(-1)
    tokens = tokens.tolist()
    gts = gts.tolist()
    message = "\n ------------------------ \n"
    for token, gt in zip(tokens, gts):
        message += "Prediction: "
        for tok in token:
            message += prgIdxToToken.get(tok) + " "
        message += "\n Target : "
        for tok in gt:
            message += prgIdxToToken.get(tok) + " "
        message += "\n ------------------------ \n"
    return message


def get_per_round_acc(preds, gts, penalties):
    res = {}
    for img_preds, img_gt, img_pen in zip(preds, gts, penalties):
        img_preds = list(img_preds)
        img_gt = list(img_gt)
        img_pen = list(img_pen)
        for i, (pred, gt, pen) in enumerate(zip(img_preds, img_gt, img_pen)):
            _round = str(i + 1)
            if _round not in res:
                res[_round] = {
                    "correct": 0,
                    "all": 0
                }
            res[_round]["all"] += 1
            if pred == gt:
                # A correct answer is down-weighted by 0.5 per penalty point.
                res[_round]["correct"] += 0.5 ** pen

    textOut = "\n --------------- Per round Acc --------------- \n"
    for k in res:
        textOut += "{}: {} %\n".format(k, 100.0 * (res[k]["correct"] / res[k]["all"]))
    return textOut
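# With pen == 0 a correct answer contributes 0.5**0 == 1.0, so the all-zero
# penalties produced above reduce this to plain per-round accuracy.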


def get_per_question_type_acc(preds, gts, qtypes, penalties):
    res1 = {}
    res2 = {}

    for img_preds, img_gt, img_qtypes, img_pen in zip(preds, gts, qtypes, penalties):
        img_pen = list(img_pen)
        for pred, gt, temp, pen in zip(img_preds, img_gt, img_qtypes, img_pen):
            if temp not in res1:
                res1[temp] = {
                    "correct": 0,
                    "all": 0
                }
            temp_cat = temp.split("-")[0]
            if temp_cat not in res2:
                res2[temp_cat] = {
                    "correct": 0,
                    "all": 0
                }
            res1[temp]["all"] += 1
            res2[temp_cat]["all"] += 1

            if pred == gt:
                res1[temp]["correct"] += 0.5 ** pen
                res2[temp_cat]["correct"] += 0.5 ** pen

    textOut = "\n --------------- Per question Type Acc --------------- \n"
    for k in res1:
        textOut += "{}: {} %\n".format(k, 100.0 * (res1[k]["correct"] / res1[k]["all"]))

    textOut += "\n --------------- Per question Category Acc --------------- \n"
    for k in res2:
        textOut += "{}: {} %\n".format(k, 100.0 * (res2[k]["correct"] / res2[k]["all"]))
    return textOut
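# The question type is the first token of the ground-truth program (e.g.
# "count-all"), and its category is the prefix before the first dash, so
# "count-all" is tallied under "count".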


def decode(tokens, prgIdxToToken, target=False):
    """Like decodeProg, but also accepts plain (nested) lists of token ids."""
    if not isinstance(tokens, list):
        tokens = tokens.tolist()

    progsBatch = []
    for token in tokens:
        prog = []
        for tok in token:
            if tok == 2:  # <END> has index 2
                break
            prog.append(prgIdxToToken.get(tok))
        if target:
            prog = prog[1:]
        progsBatch.append(prog)
    return progsBatch


if __name__ == "__main__":
    optsC = OptionsC().parse()
    optsQ = OptionsQ().parse()

    exe = Execution(optsQ, optsC)
    # run() expects 'train', 'test_with_gt' or 'test_with_pred'; anything else exits.
    exe.run("test_with_gt")
    print("[INFO] Done ...")