From ce9c902570b9d4aa1bb80389a5c5f1e84029880d Mon Sep 17 00:00:00 2001 From: Adnen Abdessaied Date: Mon, 20 Mar 2023 16:07:24 +0100 Subject: [PATCH] Update 'prog_generator/models.py' --- prog_generator/models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/prog_generator/models.py b/prog_generator/models.py index da1f037..3323933 100644 --- a/prog_generator/models.py +++ b/prog_generator/models.py @@ -456,6 +456,12 @@ class SeqToSeqC(nn.Module): predSoftmax, progHC = self.decoder(prog, encOut, capO) return predSoftmax, progHC + def sample(self, cap): + with torch.no_grad(): + encOut, capO = self.encoder(cap) + outputTokens, outputLogProbs = self.decoder.sample(encOut, capO) + outputTokens = torch.stack(outputTokens, 0).transpose(0, 1) + return outputTokens class SeqToSeqQ(nn.Module): def __init__(self, encoder, decoder):