Update 'prog_generator/models.py'

This commit is contained in:
Adnen Abdessaied 2023-03-20 16:07:24 +01:00
parent 6de837ea53
commit ce9c902570

View file

@ -456,6 +456,12 @@ class SeqToSeqC(nn.Module):
predSoftmax, progHC = self.decoder(prog, encOut, capO) predSoftmax, progHC = self.decoder(prog, encOut, capO)
return predSoftmax, progHC return predSoftmax, progHC
def sample(self, cap):
with torch.no_grad():
encOut, capO = self.encoder(cap)
outputTokens, outputLogProbs = self.decoder.sample(encOut, capO)
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
return outputTokens
class SeqToSeqQ(nn.Module): class SeqToSeqQ(nn.Module):
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):