diff --git a/prog_generator/models.py b/prog_generator/models.py index da1f037..3323933 100644 --- a/prog_generator/models.py +++ b/prog_generator/models.py @@ -456,6 +456,12 @@ class SeqToSeqC(nn.Module): predSoftmax, progHC = self.decoder(prog, encOut, capO) return predSoftmax, progHC + def sample(self, cap): + with torch.no_grad(): + encOut, capO = self.encoder(cap) + outputTokens, outputLogProbs = self.decoder.sample(encOut, capO) + outputTokens = torch.stack(outputTokens, 0).transpose(0, 1) + return outputTokens class SeqToSeqQ(nn.Module): def __init__(self, encoder, decoder):