# neuro-symbolic-visual-dialog/prog_generator/models.py


"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
import torch
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class FC(nn.Module):
    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
        super(FC, self).__init__()
        self.dropout_r = dropout_r
        self.use_relu = use_relu
        self.linear = nn.Linear(in_size, out_size)
        if use_relu:
            self.relu = nn.ReLU(inplace=True)
        if dropout_r > 0:
            self.dropout = nn.Dropout(dropout_r)

    def forward(self, x):
        x = self.linear(x)
        if self.use_relu:
            x = self.relu(x)
        if self.dropout_r > 0:
            x = self.dropout(x)
        return x


class MLP(nn.Module):
    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
        super(MLP, self).__init__()
        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        return self.linear(self.fc(x))


class LayerNorm(nn.Module):
    def __init__(self, size, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(size))
        self.b_2 = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
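
# NOTE (LayerNorm): normalizes over the last dimension with a learned scale (a_2)
# and shift (b_2). Unlike nn.LayerNorm, which divides by sqrt(var + eps), this
# variant divides by (std + eps); the numerical difference is negligible in practice.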


class MHAtt(nn.Module):
    def __init__(self, opts):
        super(MHAtt, self).__init__()
        self.opts = opts
        self.linear_v = nn.Linear(opts.hiddenDim, opts.hiddenDim)
        self.linear_k = nn.Linear(opts.hiddenDim, opts.hiddenDim)
        self.linear_q = nn.Linear(opts.hiddenDim, opts.hiddenDim)
        self.linear_merge = nn.Linear(opts.hiddenDim, opts.hiddenDim)
        self.dropout = nn.Dropout(opts.dropout)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)
        v = self.linear_v(v).view(
            n_batches,
            -1,
            self.opts.multiHead,
            self.opts.hiddenSizeHead
        ).transpose(1, 2)

        k = self.linear_k(k).view(
            n_batches,
            -1,
            self.opts.multiHead,
            self.opts.hiddenSizeHead
        ).transpose(1, 2)

        q = self.linear_q(q).view(
            n_batches,
            -1,
            self.opts.multiHead,
            self.opts.hiddenSizeHead
        ).transpose(1, 2)

        atted = self.att(v, k, q, mask)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches,
            -1,
            self.opts.hiddenDim
        )
        atted = self.linear_merge(atted)
        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1)
        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value)
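
# NOTE (MHAtt): standard multi-head scaled dot-product attention. Assuming
# opts.hiddenDim == opts.multiHead * opts.hiddenSizeHead, inputs of shape
# (n_batches, seq_len, hiddenDim) are projected and reshaped to
# (n_batches, multiHead, seq_len, hiddenSizeHead), attended, and merged back to
# (n_batches, seq_len, hiddenDim) by linear_merge. `mask` is broadcast against the
# (n_batches, multiHead, seq_len_q, seq_len_k) score tensor, so a padding mask of
# shape (n_batches, 1, 1, seq_len_k) works.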


class FFN(nn.Module):
    def __init__(self, opts):
        super(FFN, self).__init__()
        self.mlp = MLP(
            in_size=opts.hiddenDim,
            mid_size=opts.FeedForwardSize,
            out_size=opts.hiddenDim,
            dropout_r=opts.dropout,
            use_relu=True
        )

    def forward(self, x):
        return self.mlp(x)


class SA(nn.Module):
    def __init__(self, opts):
        super(SA, self).__init__()
        self.mhatt = MHAtt(opts)
        self.ffn = FFN(opts)
        self.dropout1 = nn.Dropout(opts.dropout)
        self.norm1 = LayerNorm(opts.hiddenDim)
        self.dropout2 = nn.Dropout(opts.dropout)
        self.norm2 = LayerNorm(opts.hiddenDim)

    def forward(self, x, x_mask):
        x = self.norm1(x + self.dropout1(
            self.mhatt(x, x, x, x_mask)
        ))
        x = self.norm2(x + self.dropout2(
            self.ffn(x)
        ))
        return x


class AttFlat(nn.Module):
    def __init__(self, opts):
        super(AttFlat, self).__init__()
        self.opts = opts
        self.mlp = MLP(
            in_size=opts.hiddenDim,
            mid_size=opts.FlatMLPSize,
            out_size=opts.FlatGlimpses,
            dropout_r=opts.dropout,
            use_relu=True
        )
        # FLAT_GLIMPSES = 1
        self.linear_merge = nn.Linear(
            opts.hiddenDim * opts.FlatGlimpses,
            opts.FlatOutSize
        )

    def forward(self, x, x_mask):
        att = self.mlp(x)
        att = att.masked_fill(
            x_mask.squeeze(1).squeeze(1).unsqueeze(2),
            -1e9
        )
        att = F.softmax(att, dim=1)

        att_list = []
        for i in range(self.opts.FlatGlimpses):
            att_list.append(
                torch.sum(att[:, :, i: i + 1] * x, dim=1)
            )

        x_atted = torch.cat(att_list, dim=1)
        x_atted = self.linear_merge(x_atted)
        return x_atted
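
# NOTE (AttFlat): attention pooling over the sequence dimension. The MLP scores
# each position per glimpse, padding positions are masked out, and the softmax
# weights produce FlatGlimpses weighted sums of size hiddenDim each, which are
# concatenated and projected to FlatOutSize.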


class CaptionEncoder(nn.Module):
    def __init__(self, opts, textVocabSize):
        super(CaptionEncoder, self).__init__()
        self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
        bidirectional = opts.bidirectional > 0
        self.lstmC = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=opts.hiddenDim,
            num_layers=opts.numLayers,
            batch_first=True,
            bidirectional=bidirectional
        )
        if bidirectional:
            opts.hiddenDim *= 2
            opts.hiddenSizeHead *= 2
            opts.FlatOutSize *= 2

        self.attCap = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
        self.attFlatCap = AttFlat(opts)
        self.fc = nn.Linear(opts.hiddenDim, opts.hiddenDim)

    def forward(self, cap, hist=None):
        capMask = self.make_mask(cap.unsqueeze(2))
        cap = self.embedding(cap)
        cap, (_, _) = self.lstmC(cap)
        capO = cap.detach().clone()

        for attC in self.attCap:
            cap = attC(cap, capMask)

        # (batchSize, 512)
        cap = self.attFlatCap(cap, capMask)
        encOut = self.fc(cap)
        return encOut, capO
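
    # Assumed helper: forward() above calls self.make_mask, which is not defined
    # elsewhere in this class; this padding-mask sketch mirrors
    # QuestEncoder_1.make_mask below.
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)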


class QuestEncoder_1(nn.Module):
    """
    Concat encoder
    """
    def __init__(self, opts, textVocabSize):
        super(QuestEncoder_1, self).__init__()
        bidirectional = opts.bidirectional > 0
        self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
        self.lstmQ = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=opts.hiddenDim,
            num_layers=opts.numLayers,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.lstmH = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=opts.hiddenDim,
            num_layers=opts.numLayers,
            bidirectional=bidirectional,
            batch_first=True
        )
        if bidirectional:
            opts.hiddenDim *= 2
            opts.hiddenSizeHead *= 2
            opts.FlatOutSize *= 2

        self.attQues = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
        self.attHist = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
        self.attFlatQuest = AttFlat(opts)
        self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)

    def forward(self, quest, hist):
        questMask = self.make_mask(quest.unsqueeze(2))
        histMask = self.make_mask(hist.unsqueeze(2))
        # quest = F.tanh(self.embedding(quest))
        quest = self.embedding(quest)
        quest, (_, _) = self.lstmQ(quest)
        questO = quest.detach().clone()

        hist = self.embedding(hist)
        hist, (_, _) = self.lstmH(hist)

        for attQ, attH in zip(self.attQues, self.attHist):
            quest = attQ(quest, questMask)
            hist = attH(hist, histMask)

        # (batchSize, 512)
        quest = self.attFlatQuest(quest, questMask)

        # hist: (batchSize, length, 512)
        attWeights = torch.sum(torch.mul(hist, quest.unsqueeze(1)), -1)
        attWeights = torch.softmax(attWeights, -1)
        hist = torch.sum(torch.mul(hist, attWeights.unsqueeze(2)), 1)

        encOut = self.fc(torch.cat([quest, hist], -1))
        return encOut, questO

    # Masking
    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)
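
# NOTE (QuestEncoder_1): the "concat" encoder expects the dialog history as one
# flattened token sequence. The question is self-attended and pooled by AttFlat,
# the history is self-attended and then weighted by dot-product attention with
# the pooled question, and the two vectors are concatenated and projected back
# to hiddenDim.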


class QuestEncoder_2(nn.Module):
    """
    Stack encoder
    """
    def __init__(self, opts, textVocabSize):
        super(QuestEncoder_2, self).__init__()
        bidirectional = opts.bidirectional > 0
        self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
        self.lstmQ = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=opts.hiddenDim,
            num_layers=opts.numLayers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.lstmH = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=opts.hiddenDim,
            num_layers=opts.numLayers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        if bidirectional:
            opts.hiddenDim *= 2

        self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)

    def forward(self, quest, hist):
        quest = torch.tanh(self.embedding(quest))
        quest, (questH, _) = self.lstmQ(quest)
        # Concatenate the last hidden states from the forward and backward pass
        # of the bidirectional LSTM.
        lastHiddenForward = questH[1:2, :, :].squeeze(0)
        lastHiddenBackward = questH[3:4, :, :].squeeze(0)
        # questH: (batchSize, 512)
        questH = torch.cat([lastHiddenForward, lastHiddenBackward], -1)
        questO = quest.detach().clone()

        hist = torch.tanh(self.embedding(hist))
        numRounds = hist.size(1)
        histFeat = []
        for i in range(numRounds):
            round_i = hist[:, i, :, :]
            _, (round_i_h, _) = self.lstmH(round_i)
            # Same as before
            lastHiddenForward = round_i_h[1:2, :, :].squeeze(0)
            lastHiddenBackward = round_i_h[3:4, :, :].squeeze(0)
            histFeat.append(torch.cat([lastHiddenForward, lastHiddenBackward], -1))

        # hist: (batchSize, rounds, 512)
        histFeat = torch.stack(histFeat, 1)
        attWeights = torch.sum(torch.mul(histFeat, questH.unsqueeze(1)), -1)
        attWeights = torch.softmax(attWeights, -1)
        histFeat = torch.sum(torch.mul(histFeat, attWeights.unsqueeze(2)), 1)

        encOut = self.fc(torch.cat([questH, histFeat], -1))
        return encOut, questO
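
# NOTE (QuestEncoder_2): the "stack" encoder expects the history as a 3-D tensor
# (batchSize, rounds, tokens) and encodes each round separately with lstmH, then
# attends over the per-round features with the question representation. The
# hidden-state indexing (questH[1:2] / questH[3:4]) assumes a 2-layer
# bidirectional LSTM, i.e. h_n has exactly four slices.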


class Decoder(nn.Module):
    def __init__(self, opts, progVocabSize, maxLen, startID=1, endID=2):
        super(Decoder, self).__init__()
        self.numLayers = opts.numLayers
        self.bidirectional = opts.bidirectional > 0
        self.maxLen = maxLen
        self.startID = startID
        self.endID = endID
        self.embedding = nn.Embedding(progVocabSize, opts.embedDim)
        self.lstmProg = nn.LSTM(
            input_size=opts.embedDim,
            hidden_size=2 * opts.hiddenDim if self.bidirectional else opts.hiddenDim,
            num_layers=opts.numLayers,
            batch_first=True,
            # bidirectional=self.bidirectional,
        )
        hiddenDim = opts.hiddenDim
        if self.bidirectional:
            hiddenDim *= 2
        self.fcAtt = nn.Linear(2 * hiddenDim, hiddenDim)
        self.fcOut = nn.Linear(hiddenDim, progVocabSize)

    def initPrgHidden(self, encOut):
        hidden = [encOut for _ in range(self.numLayers)]
        hidden = torch.stack(hidden, 0).contiguous()
        return hidden, hidden

    def forwardStep(self, prog, progH, questO):
        batchSize = prog.size(0)
        inputDim = questO.size(1)
        prog = self.embedding(prog)
        outProg, progH = self.lstmProg(prog, progH)

        att = torch.bmm(outProg, questO.transpose(1, 2))
        att = F.softmax(att.view(-1, inputDim), 1).view(batchSize, -1, inputDim)
        context = torch.bmm(att, questO)

        # (batchSize, progLength, hiddenDim)
        out = torch.tanh(self.fcAtt(torch.cat([outProg, context], dim=-1)))
        # (batchSize, progLength, progVocabSize)
        out = self.fcOut(out)
        predSoftmax = F.log_softmax(out, 2)
        return predSoftmax, progH

    def forward(self, prog, encOut, questO):
        progH = self.initPrgHidden(encOut)
        predSoftmax, progH = self.forwardStep(prog, progH, questO)
        return predSoftmax, progH

    def sample(self, encOut, questO):
        batchSize = encOut.size(0)
        cudaFlag = encOut.is_cuda
        progH = self.initPrgHidden(encOut)
        # prog = progCopy[:, 0:3]
        prog = torch.LongTensor(batchSize, 1).fill_(self.startID)
        # prog = torch.cat((progStart, progEnd), -1)
        if cudaFlag:
            prog = prog.cuda()
        outputLogProbs = []
        outputTokens = []

        def decode(i, output):
            # Greedy decoding: keep the highest-scoring token and its
            # log-probability at every step so sample() can return them.
            logProb, tokens = output.topk(1, dim=-1)
            tokens = tokens.view(batchSize, -1)
            outputTokens.append(tokens.squeeze(1))
            outputLogProbs.append(logProb.view(batchSize, -1).squeeze(1))
            return tokens

        for i in range(self.maxLen):
            predSoftmax, progH = self.forwardStep(prog, progH, questO)
            prog = decode(i, predSoftmax)

        return outputTokens, outputLogProbs
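
# NOTE (Decoder): sample() decodes greedily for exactly maxLen steps and does not
# stop early when endID is emitted, so any truncation at the end token has to
# happen downstream.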


class SeqToSeqC(nn.Module):
    def __init__(self, encoder, decoder):
        super(SeqToSeqC, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, cap, imgFeat, prog):
        encOut, capO = self.encoder(cap, imgFeat)
        predSoftmax, progHC = self.decoder(prog, encOut, capO)
        return predSoftmax, progHC


class SeqToSeqQ(nn.Module):
    def __init__(self, encoder, decoder):
        super(SeqToSeqQ, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, quest, hist, prog):
        encOut, questO = self.encoder(quest, hist)
        predSoftmax, progHC = self.decoder(prog, encOut, questO)
        return predSoftmax, progHC

    def sample(self, quest, hist):
        with torch.no_grad():
            encOut, questO = self.encoder(quest, hist)
            outputTokens, outputLogProbs = self.decoder.sample(encOut, questO)
            outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
        return outputTokens
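

# Minimal smoke-test sketch (illustrative, not part of the original training
# pipeline). The attribute names on `opts` mirror the ones used above; the
# concrete values are assumptions chosen so the shape constraints hold
# (hiddenDim == multiHead * hiddenSizeHead and FlatOutSize == hiddenDim).
if __name__ == "__main__":
    from types import SimpleNamespace

    opts = SimpleNamespace(
        embedDim=32,
        hiddenDim=64,
        multiHead=8,
        hiddenSizeHead=8,
        FeedForwardSize=128,
        FlatMLPSize=32,
        FlatGlimpses=1,
        FlatOutSize=64,
        layers=1,
        numLayers=2,
        bidirectional=0,
        dropout=0.1,
    )
    textVocabSize, progVocabSize, maxLen = 100, 50, 10

    encoder = QuestEncoder_1(opts, textVocabSize)
    decoder = Decoder(opts, progVocabSize, maxLen, startID=1, endID=2)
    model = SeqToSeqQ(encoder, decoder)

    quest = torch.randint(1, textVocabSize, (4, 12))   # (batchSize, questionLen)
    hist = torch.randint(1, textVocabSize, (4, 40))    # (batchSize, flattenedHistoryLen)
    prog = torch.randint(1, progVocabSize, (4, maxLen))

    predSoftmax, _ = model(quest, hist, prog)
    print(predSoftmax.shape)   # expected: (4, maxLen, progVocabSize)

    tokens = model.sample(quest, hist)
    print(tokens.shape)        # expected: (4, maxLen)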