Update 'prog_generator/models.py'
This commit is contained in:
parent
6de837ea53
commit
ce9c902570
|
@ -456,6 +456,12 @@ class SeqToSeqC(nn.Module):
|
|||
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
||||
return predSoftmax, progHC
|
||||
|
||||
def sample(self, cap):
|
||||
with torch.no_grad():
|
||||
encOut, capO = self.encoder(cap)
|
||||
outputTokens, outputLogProbs = self.decoder.sample(encOut, capO)
|
||||
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
|
||||
return outputTokens
|
||||
|
||||
class SeqToSeqQ(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
|
|
Loading…
Reference in a new issue