# OLViT/src/models/generative_model.py
# code is partly inspired from https://pytorch.org/tutorials/beginner/translation_transformer.html
from src.models.state_tracker_model import StateTrackerModel
from src.utils.batch_interfaces import batch_interface_simmc2_to_dvd, batch_interface_avsd_to_dvd
from dataclasses import dataclass
import torch
from torch import nn
from torchtext.data.metrics import bleu_score
import json
import os
from transformers import AutoTokenizer
import nltk
import numpy as np
from src.utils.text_utils import normalize_sentence, translate_from_ids_to_text
class GenerativeModel(StateTrackerModel):
    def __init__(self, config, output_path=None):
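        """Generative answer decoder built on top of the state tracker model.

        Wraps the encoder constructed by StateTrackerModel (passed as
        custom_encoder) in a full encoder-decoder nn.Transformer and adds a
        linear layer that maps decoder outputs to vocabulary logits.
        pad_id=1 and unk_id=3 correspond to the RoBERTa-style vocabulary of
        the pretrained tokenizer.
        """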
        super().__init__(config, output_path=output_path)
        self.transformer = nn.Transformer(
            d_model=self.model_input_dim,
            batch_first=True,
            dropout=self.config['dropout_p'],
            dim_feedforward=self.config['dim_feedforward'],
            nhead=self.config['n_heads'],
            num_encoder_layers=self.config['n_encoder_layers'],
            num_decoder_layers=self.config['n_decoder_layers'],
            custom_encoder=self.encoder
        )
        self.prob_generator = nn.Linear(self.model_input_dim, self.config['vocab_size'])
        self.pad_id = 1
        self.unk_id = 3
        self.loss = nn.CrossEntropyLoss(ignore_index=self.pad_id)
        # tokenizer for translation from ids to text
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['pretrained_lm_name'])
        # ---TODO: Remove ------
        self.results = {}
        self.epoch_count = 0
        # -----------------------
        self.batch_interface = batch_interface_simmc2_to_dvd

    def encode_object_descriptions(self, vft):
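        """Encode textual object descriptions into per-object feature tokens.

        Each object description in vft is embedded with the pretrained LM,
        projected to a smaller dimension and summarized into a single object
        token; the tokens are stacked and returned in the format
        [batch_size x frames x emb_dim x obj_number] (with a single frame).
        """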
        # embed the object descriptions using BERT and then create the object tokens using transformer layers
        if self.config['feature_type'] == "object_text_features":
            object_features = []
            for i in range(vft.shape[1]):
                object_description = vft[:, i, :]
                object_description_mask = (object_description != 1)
                embedded_object_description = self.apply_pretrained_lm(object_description, object_description_mask)
                # map embeddings to a smaller size (motivation: reduce the transformer size of the object description encoder)
                embedded_object_description = self.linear_projection_object_description(embedded_object_description)
                # apply transformer to encode the object description
                object_token = self.object_description_encoder(embedded_object_description)
                object_features.append(object_token)
            object_features = torch.concat(object_features, dim=1)
            # add frame dimension (only one frame in this case)
            object_features = object_features.unsqueeze(1)
            # bring the data to the format [batch_size x frames x emb_dim (desc_text_len) x obj_number]
            vft = object_features.permute(0, 1, 3, 2)
        return vft

    def create_target_mask(self, size):
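        """Create an additive causal mask of shape (size, size): 0 on and below
        the diagonal, -inf above it, so each decoder position can only attend
        to earlier positions."""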
        mask = torch.triu(torch.ones((size, size), device=self.device), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def generate_prob_for_next_tokens(self, input, answer_emb, tgt_mask, input_mask, answer_mask):
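        """Teacher-forced pass: encode the combined input and decode the given
        answer embeddings with a causal mask, returning next-token logits of
        shape [batch, answer_len, vocab_size]. Note: answer_mask is currently
        not passed to the decoder."""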
        x = self.transformer.encoder(input, src_key_padding_mask=input_mask)
        dec_out = self.transformer.decoder(answer_emb, x, tgt_mask)
        probs = self.prob_generator(dec_out)
        return probs

    def generate_complete_answers(self, input, input_mask):
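        """Greedily generate a complete answer for every sample in the batch.

        The encoder memory is computed once for the whole batch; each answer is
        then decoded token by token (at most 40 tokens), starting from the <s>
        token (id 0) and stopping at </s> (id 2). Returns a [batch, 40] int
        tensor of generated token ids, padded with 1.
        """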
        # encode the complete batch of questions
        memory = self.transformer.encoder(input, src_key_padding_mask=input_mask)
        # 40 = max answer length; the buffer is initialized with padding tokens (id = 1)
        generated_answers = torch.ones(memory.shape[0], 40, dtype=torch.int)
        # generate the answer for each individual question from the batch
        for i in range(memory.shape[0]):
            memory_i = memory[i, :, :]
            memory_i = memory_i.unsqueeze(0)
            # pass the start token <s> to the decoder as first input (RoBERTa vocab: "<s>": 0, "</s>": 2)
            answer_i = torch.zeros((1, 1), dtype=torch.int, device=self.device)
            for j in range(40):  # 40 = max answer length
                answer_i_emb = self.prepare_lang_emb(answer_i, torch.ones((1, answer_i.shape[0]), device=self.device, dtype=torch.int16))
                tgt_mask = self.create_target_mask(answer_i.shape[1])
                decoder_output = self.transformer.decoder(answer_i_emb, memory_i, tgt_mask)
                prob = self.prob_generator(decoder_output[:, -1, :])
                next_word = prob.argmax()
                answer_i = torch.concat([answer_i, next_word.unsqueeze(0).unsqueeze(0)], dim=1)
                if next_word.item() == 2:  # eos token in the RoBERTa vocab ("</s>": 2)
                    break
            generated_answers[i, :answer_i.shape[1] - 1] = answer_i[0, 1:]
        return generated_answers

    def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None, answer_emb=None, answer_mask=None, state_generation_mode=False):
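        """Combine video, language and state embeddings and run the transformer.

        In 'train' mode (or when state_generation_mode is set) the method
        returns next-token logits via teacher forcing; in 'val' mode it
        autoregressively generates complete answers.
        """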
        # combine state and embeddings
        input = self.combiner(
            video_emb,
            language_emb,
            language_emb_mask,
            v_state,
            d_state
        )
        # create input mask based on the language_emb_mask (complete video is unmasked)
        input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device)
        offset = 0
        if v_state is not None: offset += 1
        if d_state is not None: offset += 1
        # offset is caused by the state vectors
        input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask
        tgt_mask = self.create_target_mask(answer_emb.shape[1])
        # -------TODO: Mask padded object embeddings when text based object embeddings are used -------------
        if self.mode == 'train' or state_generation_mode:
            probs = self.generate_prob_for_next_tokens(input, answer_emb, tgt_mask, input_mask, answer_mask)
            return probs
        elif self.mode == 'val':
            generated_answers = self.generate_complete_answers(input, input_mask)
            return generated_answers

    def prepare_answer_emb_and_mask(self, answer, answer_mask):
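        """Embed the ground-truth answer with a lower-triangular (causal)
        attention mask, project and positionally encode it, and invert the
        padding mask to PyTorch conventions (True = padding). The last position
        is dropped so the decoder input is the answer shifted by one token."""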
        mask = torch.tril(torch.ones((answer.shape[1], answer.shape[1]), device=self.device))
        mask = mask.unsqueeze(0)
        mask = mask.expand(answer.shape[0], -1, -1)
        answer_emb = self.apply_pretrained_lm(answer, mask)
        answer_emb = self.linear_projection_text(answer_emb)
        answer_emb = self.append_ids(answer_emb, [1, 0], 2)
        answer_emb = self.positional_encoder(answer_emb)
        # pytorch interprets True in a mask as padding
        answer_mask = ~answer_mask
        answer_emb_final = answer_emb[:, :-1].detach()
        answer_mask_final = answer_mask[:, :-1].detach()
        return answer_emb_final, answer_mask_final

    def answer_query(self, query, query_mask, vft, v_state=None, d_state=None, answer=None, answer_mask=None, state_generation_mode=False):
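        """Embed video features, query and answer, then run apply_model.
        Returns logits in train mode and generated answer ids in val mode."""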
        video_emb = self.prepare_video_emb(vft)
        lang_emb = self.prepare_lang_emb(query, query_mask)
        answer_emb, answer_mask = self.prepare_answer_emb_and_mask(answer, answer_mask)
        output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state, answer_emb, answer_mask, state_generation_mode)
        return output

    def training_step(self, train_batch, batch_idx):
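        """Lightning training step: convert the batch to the DVD interface,
        compute next-token logits and the cross-entropy loss against the answer
        shifted by one token; padding (and unk tokens mapped to padding) are
        ignored by the loss."""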
        train_batch = self.batch_interface(train_batch, feature_type=self.config['feature_type'])
        if self.config['feature_type'] == "object_text_features":
            train_batch.vft = self.encode_object_descriptions(train_batch.vft)
        logits = self.forward(train_batch)
        logits = logits.permute(0, 2, 1)
        # replace unknown tokens (id = 3) with padding tokens so the loss ignores them as well -> avoids a model that outputs unk tokens
        train_batch.answer[train_batch.answer == 3] = 1
        loss = self.loss(logits, train_batch.answer[:, 1:])  # ignore padding
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
        return loss

    def get_next_token_pred_as_text_and_logits(self, batch):
        # set mode to train to get the logits instead of completely generated sentences
        self.mode = 'train'
        logits = self.forward(batch)
        logits = logits.permute(0, 2, 1)
        predicted_tokens = []
        for j in range(logits.shape[0]):
            l = logits[j, :, :]
            ids = [l[:, i].argmax().item() for i in range(l.shape[1])]
            text = translate_from_ids_to_text(ids, self.tokenizer)
            predicted_tokens.append(text)
        # set mode back to val
        self.mode = 'val'
        return predicted_tokens, logits

    def calculate_bleu_score(self, generated_answer_ids, correct_answer):
        # calculate the bleu score for the generated answers compared to the provided correct answers
        bleu4_scores = []
        all_generated_answers = []
        for i in range(generated_answer_ids.shape[0]):
            generated_answer = generated_answer_ids[i, :].tolist()
            generated_answer_text = translate_from_ids_to_text(generated_answer, self.tokenizer)
            all_generated_answers.append(generated_answer_text)
            correct_answer_text_i = correct_answer[i]
            score4 = nltk.translate.bleu_score.sentence_bleu(
                [normalize_sentence(correct_answer_text_i)],
                normalize_sentence(generated_answer_text),
                smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method7
            )
            bleu4_scores.append(score4)
        bleu4_score = np.mean(bleu4_scores)
        return bleu4_score, all_generated_answers

    def translate_answer_ids_to_text(self, answer):
        correct_answer_text = []
        for i in range(answer.shape[0]):
            correct_answer_i = answer[i, :].tolist()
            correct_answer_text_i = translate_from_ids_to_text(correct_answer_i, self.tokenizer)
            correct_answer_text.append(correct_answer_text_i)
        return correct_answer_text

    def validation_step(self, val_batch, batch_idx):
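        """Lightning validation step: generate complete answers to compute the
        BLEU-4 score, then rerun the batch in next-token-prediction mode to
        compute the validation loss."""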
        val_batch = self.batch_interface(val_batch, feature_type=self.config['feature_type'])
        if self.config['feature_type'] == "object_text_features":
            val_batch.vft = self.encode_object_descriptions(val_batch.vft)
        correct_answer_text = self.translate_answer_ids_to_text(val_batch.answer)
        generated_answer_ids = self.forward(val_batch)
        # calculate and log the bleu score of the generated answers compared to the provided correct answers
        bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text)
        self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0])
        # calculate and log the validation loss based on the results from next token prediction (train mode needed)
        predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(val_batch)
        loss = self.loss(logits, val_batch.answer[:, 1:])  # ignore padding
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])
        return {'next_token_predictions': predicted_tokens, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text}

    def test_step(self, test_batch, batch_idx):
        dialog_id = test_batch['dialog_id']
        turn_id = test_batch['turn_id']
        test_batch = self.batch_interface(test_batch, feature_type=self.config['feature_type'])
        if self.config['feature_type'] == "object_text_features":
            test_batch.vft = self.encode_object_descriptions(test_batch.vft)
        correct_answer_text = self.translate_answer_ids_to_text(test_batch.answer)
        generated_answer_ids = self.forward(test_batch)
        # calculate and log the bleu score of the generated answers compared to the provided correct answers
        bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text)
        self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0])
        # calculate and log the loss based on the results from next token prediction (train mode needed, still logged as "val_loss")
        predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(test_batch)
        loss = self.loss(logits, test_batch.answer[:, 1:])  # ignore padding
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])
        return {'turn_id': turn_id, 'next_token_predictions': predicted_tokens, 'dialog_id': dialog_id, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text}

    def test_epoch_end(self, outputs):
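        """Write the collected test outputs to disk: either in the DSTC11 SIMMC
        subtask-4 submission format or, otherwise, as a per-sample JSON dump of
        next-token predictions, generated answers and reference answers."""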
        if self.config['output_format'] == 'submission':
            responses = []
            for output in outputs:
                for t_id, d_id, answer in zip(output['turn_id'], output['dialog_id'], output['generated_answers']):
                    sample = {
                        'dialog_id': d_id,
                        'predictions': [
                            {
                                'turn_id': t_id,
                                'response': answer
                            }
                        ]
                    }
                    responses.append(sample)
            name = 'dstc11-simmc-devtest-pred-subtask-4-generation.json'
            with open(os.path.join(self.output_path, name), 'w') as file:
                json.dump(responses, file)
        else:
            result_idx = 0
            for output in outputs:
                for j in range(len(output['next_token_predictions'])):
                    pred = " "
                    corr = " "
                    gen = " "
                    self.results[result_idx] = {
                        'next_token_pred': pred.join(output['next_token_predictions'][j]),
                        'generated_ans': gen.join(output['generated_answers'][j]),
                        'correct': corr.join(output['correct_answers'][j])
                    }
                    result_idx += 1
            name = f'epoch_{self.epoch_count}.json'
            with open(os.path.join(self.output_path, name), 'w') as file:
                json.dump(self.results, file)

    def validation_epoch_end(self, outputs):
        result_idx = 0
        for output in outputs:
            for j in range(len(output['next_token_predictions'])):
                pred = " "
                corr = " "
                gen = " "
                self.results[result_idx] = {
                    'next_token_pred': pred.join(output['next_token_predictions'][j]),
                    'generated_ans': gen.join(output['generated_answers'][j]),
                    'correct': corr.join(output['correct_answers'][j])
                }
                result_idx += 1
        name = f'epoch_{self.epoch_count}.json'
        with open(os.path.join(self.output_path, name), 'w') as file:
            json.dump(self.results, file)
        self.results = {}
        self.epoch_count += 1

    def on_train_epoch_start(self):
        self.mode = 'train'

    def on_validation_epoch_start(self):
        self.mode = 'val'

    def on_test_epoch_start(self):
        self.mode = 'val'