make code public
This commit is contained in:
commit
9d8b93db26
26 changed files with 11937 additions and 0 deletions
476
prog_generator/models.py
Normal file
476
prog_generator/models.py
Normal file
|
@ -0,0 +1,476 @@
|
|||
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FC(nn.Module):
|
||||
def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
|
||||
super(FC, self).__init__()
|
||||
self.dropout_r = dropout_r
|
||||
self.use_relu = use_relu
|
||||
|
||||
self.linear = nn.Linear(in_size, out_size)
|
||||
|
||||
if use_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
if dropout_r > 0:
|
||||
self.dropout = nn.Dropout(dropout_r)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
|
||||
if self.use_relu:
|
||||
x = self.relu(x)
|
||||
|
||||
if self.dropout_r > 0:
|
||||
x = self.dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
|
||||
super(MLP, self).__init__()
|
||||
|
||||
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
||||
self.linear = nn.Linear(mid_size, out_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.fc(x))
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, size, eps=1e-6):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
self.a_2 = nn.Parameter(torch.ones(size))
|
||||
self.b_2 = nn.Parameter(torch.zeros(size))
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(-1, keepdim=True)
|
||||
std = x.std(-1, keepdim=True)
|
||||
|
||||
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
||||
|
||||
|
||||
class MHAtt(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(MHAtt, self).__init__()
|
||||
self.opts = opts
|
||||
|
||||
self.linear_v = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_k = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_q = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
self.linear_merge = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
self.dropout = nn.Dropout(opts.dropout)
|
||||
|
||||
def forward(self, v, k, q, mask):
|
||||
n_batches = q.size(0)
|
||||
|
||||
v = self.linear_v(v).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
k = self.linear_k(k).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
q = self.linear_q(q).view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.multiHead,
|
||||
self.opts.hiddenSizeHead
|
||||
).transpose(1, 2)
|
||||
|
||||
atted = self.att(v, k, q, mask)
|
||||
atted = atted.transpose(1, 2).contiguous().view(
|
||||
n_batches,
|
||||
-1,
|
||||
self.opts.hiddenDim
|
||||
)
|
||||
|
||||
atted = self.linear_merge(atted)
|
||||
|
||||
return atted
|
||||
|
||||
def att(self, value, key, query, mask):
|
||||
d_k = query.size(-1)
|
||||
|
||||
scores = torch.matmul(
|
||||
query, key.transpose(-2, -1)
|
||||
) / math.sqrt(d_k)
|
||||
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask, -1e9)
|
||||
|
||||
att_map = F.softmax(scores, dim=-1)
|
||||
att_map = self.dropout(att_map)
|
||||
|
||||
return torch.matmul(att_map, value)
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(FFN, self).__init__()
|
||||
|
||||
self.mlp = MLP(
|
||||
in_size=opts.hiddenDim,
|
||||
mid_size=opts.FeedForwardSize,
|
||||
out_size=opts.hiddenDim,
|
||||
dropout_r=opts.dropout,
|
||||
use_relu=True
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class SA(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(SA, self).__init__()
|
||||
self.mhatt = MHAtt(opts)
|
||||
self.ffn = FFN(opts)
|
||||
|
||||
self.dropout1 = nn.Dropout(opts.dropout)
|
||||
self.norm1 = LayerNorm(opts.hiddenDim)
|
||||
|
||||
self.dropout2 = nn.Dropout(opts.dropout)
|
||||
self.norm2 = LayerNorm(opts.hiddenDim)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.norm1(x + self.dropout1(
|
||||
self.mhatt(x, x, x, x_mask)
|
||||
))
|
||||
|
||||
x = self.norm2(x + self.dropout2(
|
||||
self.ffn(x)
|
||||
))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AttFlat(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(AttFlat, self).__init__()
|
||||
self.opts = opts
|
||||
|
||||
self.mlp = MLP(
|
||||
in_size=opts.hiddenDim,
|
||||
mid_size=opts.FlatMLPSize,
|
||||
out_size=opts.FlatGlimpses,
|
||||
dropout_r=opts.dropout,
|
||||
use_relu=True
|
||||
)
|
||||
# FLAT_GLIMPSES = 1
|
||||
self.linear_merge = nn.Linear(
|
||||
opts.hiddenDim * opts.FlatGlimpses,
|
||||
opts.FlatOutSize
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
att = self.mlp(x)
|
||||
att = att.masked_fill(
|
||||
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
||||
-1e9
|
||||
)
|
||||
att = F.softmax(att, dim=1)
|
||||
|
||||
att_list = []
|
||||
for i in range(self.opts.FlatGlimpses):
|
||||
att_list.append(
|
||||
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
||||
)
|
||||
|
||||
x_atted = torch.cat(att_list, dim=1)
|
||||
x_atted = self.linear_merge(x_atted)
|
||||
|
||||
return x_atted
|
||||
|
||||
class CaptionEncoder(nn.Module):
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(CaptionEncoder, self).__init__()
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
bidirectional = opts.bidirectional > 0
|
||||
self.lstmC = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional
|
||||
)
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
opts.hiddenSizeHead *= 2
|
||||
opts.FlatOutSize *= 2
|
||||
|
||||
self.attCap = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
self.attFlatCap = AttFlat(opts)
|
||||
self.fc = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, cap, hist=None):
|
||||
capMask = self.make_mask(cap.unsqueeze(2))
|
||||
cap = self.embedding(cap)
|
||||
cap, (_, _) = self.lstmC(cap)
|
||||
capO = cap.detach().clone()
|
||||
|
||||
for attC in self.attCap:
|
||||
cap = attC(cap, capMask)
|
||||
# (batchSize, 512)
|
||||
cap = self.attFlatCap(cap, capMask)
|
||||
encOut = self.fc(cap)
|
||||
return encOut, capO
|
||||
|
||||
class QuestEncoder_1(nn.Module):
|
||||
"""
|
||||
Concat encoder
|
||||
"""
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(QuestEncoder_1, self).__init__()
|
||||
bidirectional = opts.bidirectional > 0
|
||||
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
self.lstmQ = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
bidirectional=bidirectional,
|
||||
batch_first=True
|
||||
)
|
||||
|
||||
self.lstmH = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
bidirectional=bidirectional,
|
||||
batch_first=True)
|
||||
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
opts.hiddenSizeHead *= 2
|
||||
opts.FlatOutSize *= 2
|
||||
self.attQues = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
self.attHist = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||
|
||||
self.attFlatQuest = AttFlat(opts)
|
||||
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, quest, hist):
|
||||
questMask = self.make_mask(quest.unsqueeze(2))
|
||||
histMask = self.make_mask(hist.unsqueeze(2))
|
||||
|
||||
# quest = F.tanh(self.embedding(quest))
|
||||
quest = self.embedding(quest)
|
||||
|
||||
quest, (_, _) = self.lstmQ(quest)
|
||||
questO = quest.detach().clone()
|
||||
|
||||
hist = self.embedding(hist)
|
||||
hist, (_, _) = self.lstmH(hist)
|
||||
|
||||
for attQ, attH in zip(self.attQues, self.attHist):
|
||||
quest = attQ(quest, questMask)
|
||||
hist = attH(hist, histMask)
|
||||
# (batchSize, 512)
|
||||
quest = self.attFlatQuest(quest, questMask)
|
||||
|
||||
# hist: (batchSize, length, 512)
|
||||
attWeights = torch.sum(torch.mul(hist, quest.unsqueeze(1)), -1)
|
||||
attWeights = torch.softmax(attWeights, -1)
|
||||
hist = torch.sum(torch.mul(hist, attWeights.unsqueeze(2)), 1)
|
||||
encOut = self.fc(torch.cat([quest, hist], -1))
|
||||
|
||||
return encOut, questO
|
||||
|
||||
# Masking
|
||||
def make_mask(self, feature):
|
||||
return (torch.sum(
|
||||
torch.abs(feature),
|
||||
dim=-1
|
||||
) == 0).unsqueeze(1).unsqueeze(2)
|
||||
|
||||
|
||||
class QuestEncoder_2(nn.Module):
|
||||
"""
|
||||
Stack encoder
|
||||
"""
|
||||
def __init__(self, opts, textVocabSize):
|
||||
super(QuestEncoder_2, self).__init__()
|
||||
bidirectional = opts.bidirectional > 0
|
||||
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||
self.lstmQ = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
|
||||
self.lstmH = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
if bidirectional:
|
||||
opts.hiddenDim *= 2
|
||||
|
||||
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||
|
||||
def forward(self, quest, hist):
|
||||
|
||||
quest = F.tanh(self.embedding(quest))
|
||||
quest, (questH, _) = self.lstmQ(quest)
|
||||
|
||||
# concatenate the last hidden states from the forward and backward pass
|
||||
# of the bidirectional lstm
|
||||
lastHiddenForward = questH[1:2, :, :].squeeze(0)
|
||||
lastHiddenBackward = questH[3:4, :, :].squeeze(0)
|
||||
|
||||
# questH: (batchSize, 512)
|
||||
questH = torch.cat([lastHiddenForward, lastHiddenBackward], -1)
|
||||
|
||||
questO = quest.detach().clone()
|
||||
|
||||
hist = F.tanh(self.embedding(hist))
|
||||
numRounds = hist.size(1)
|
||||
histFeat = []
|
||||
for i in range(numRounds):
|
||||
round_i = hist[:, i, :, :]
|
||||
_, (round_i_h, _) = self.lstmH(round_i)
|
||||
|
||||
#Same as before
|
||||
lastHiddenForward = round_i_h[1:2, :, :].squeeze(0)
|
||||
lastHiddenBackward = round_i_h[3:4, :, :].squeeze(0)
|
||||
histFeat.append(torch.cat([lastHiddenForward, lastHiddenBackward], -1))
|
||||
|
||||
# hist: (batchSize, rounds, 512)
|
||||
histFeat = torch.stack(histFeat, 1)
|
||||
attWeights = torch.sum(torch.mul(histFeat, questH.unsqueeze(1)), -1)
|
||||
attWeights = torch.softmax(attWeights, -1)
|
||||
histFeat = torch.sum(torch.mul(histFeat, attWeights.unsqueeze(2)), 1)
|
||||
encOut = self.fc(torch.cat([questH, histFeat], -1))
|
||||
return encOut, questO
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, opts, progVocabSize, maxLen, startID=1, endID=2):
|
||||
super(Decoder, self).__init__()
|
||||
self.numLayers = opts.numLayers
|
||||
self.bidirectional = opts.bidirectional > 0
|
||||
self.maxLen = maxLen
|
||||
self.startID = startID
|
||||
self.endID = endID
|
||||
|
||||
self.embedding = nn.Embedding(progVocabSize, opts.embedDim)
|
||||
self.lstmProg = nn.LSTM(
|
||||
input_size=opts.embedDim,
|
||||
hidden_size=2*opts.hiddenDim if self.bidirectional else opts.hiddenDim,
|
||||
num_layers=opts.numLayers,
|
||||
batch_first=True,
|
||||
# bidirectional=self.bidirectional,
|
||||
)
|
||||
hiddenDim = opts.hiddenDim
|
||||
if self.bidirectional:
|
||||
hiddenDim *= 2
|
||||
|
||||
self.fcAtt = nn.Linear(2*hiddenDim, hiddenDim)
|
||||
self.fcOut = nn.Linear(hiddenDim, progVocabSize)
|
||||
|
||||
def initPrgHidden(self, encOut):
|
||||
hidden = [encOut for _ in range(self.numLayers)]
|
||||
hidden = torch.stack(hidden, 0).contiguous()
|
||||
return hidden, hidden
|
||||
|
||||
def forwardStep(self, prog, progH, questO):
|
||||
batchSize = prog.size(0)
|
||||
inputDim = questO.size(1)
|
||||
prog = self.embedding(prog)
|
||||
outProg, progH = self.lstmProg(prog, progH)
|
||||
|
||||
att = torch.bmm(outProg, questO.transpose(1, 2))
|
||||
att = F.softmax(att.view(-1, inputDim), 1).view(batchSize, -1, inputDim)
|
||||
context = torch.bmm(att, questO)
|
||||
# (batchSize, progLength, hiddenDim)
|
||||
out = F.tanh(self.fcAtt(torch.cat([outProg, context], dim=-1)))
|
||||
|
||||
# (batchSize, progLength, progVocabSize)
|
||||
out = self.fcOut(out)
|
||||
predSoftmax = F.log_softmax(out, 2)
|
||||
return predSoftmax, progH
|
||||
|
||||
def forward(self, prog, encOut, questO):
|
||||
progH = self.initPrgHidden(encOut)
|
||||
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||
|
||||
return predSoftmax, progH
|
||||
|
||||
def sample(self, encOut, questO):
|
||||
batchSize = encOut.size(0)
|
||||
cudaFlag = encOut.is_cuda
|
||||
progH = self.initPrgHidden(encOut)
|
||||
# prog = progCopy[:, 0:3]
|
||||
prog = torch.LongTensor(batchSize, 1).fill_(self.startID)
|
||||
# prog = torch.cat((progStart, progEnd), -1)
|
||||
if cudaFlag:
|
||||
prog = prog.cuda()
|
||||
outputLogProbs = []
|
||||
outputTokens = []
|
||||
|
||||
def decode(i, output):
|
||||
tokens = output.topk(1, dim=-1)[1].view(batchSize, -1)
|
||||
return tokens
|
||||
|
||||
for i in range(self.maxLen):
|
||||
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||
prog = decode(i, predSoftmax)
|
||||
|
||||
return outputTokens, outputLogProbs
|
||||
|
||||
|
||||
class SeqToSeqC(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
super(SeqToSeqC, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, cap, imgFeat, prog):
|
||||
encOut, capO = self.encoder(cap, imgFeat)
|
||||
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
||||
return predSoftmax, progHC
|
||||
|
||||
|
||||
class SeqToSeqQ(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
super(SeqToSeqQ, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, quest, hist, prog):
|
||||
encOut, questO = self.encoder(quest, hist)
|
||||
predSoftmax, progHC = self.decoder(prog, encOut, questO)
|
||||
return predSoftmax, progHC
|
||||
|
||||
def sample(self, quest, hist):
|
||||
with torch.no_grad():
|
||||
encOut, questO = self.encoder(quest, hist)
|
||||
outputTokens, outputLogProbs = self.decoder.sample(encOut, questO)
|
||||
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
|
||||
return outputTokens
|
Loading…
Add table
Add a link
Reference in a new issue