Update 'prog_generator/models.py'
This commit is contained in:
parent
6de837ea53
commit
ce9c902570
1 changed files with 6 additions and 0 deletions
|
@ -456,6 +456,12 @@ class SeqToSeqC(nn.Module):
|
||||||
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
||||||
return predSoftmax, progHC
|
return predSoftmax, progHC
|
||||||
|
|
||||||
|
def sample(self, cap):
|
||||||
|
with torch.no_grad():
|
||||||
|
encOut, capO = self.encoder(cap)
|
||||||
|
outputTokens, outputLogProbs = self.decoder.sample(encOut, capO)
|
||||||
|
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
|
||||||
|
return outputTokens
|
||||||
|
|
||||||
class SeqToSeqQ(nn.Module):
|
class SeqToSeqQ(nn.Module):
|
||||||
def __init__(self, encoder, decoder):
|
def __init__(self, encoder, decoder):
|
||||||
|
|
Loading…
Reference in a new issue