OLViT/src/models/discriminative_model.py

import pandas as pd
import torch
from torch import nn

from src.models.state_tracker_model import StateTrackerModel
from src.utils.text_utils import translate_from_ids_to_text


class DiscriminativeModel(StateTrackerModel):
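    """Discriminative variant of the OLViT state tracker model.

    The combined video, language and state input is passed through a transformer
    encoder, and the transformed CLS-like dummy token is classified into one of
    40 classes (presumably the dataset's fixed set of answer candidates).
    """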
    def __init__(self, config, output_path=None):
        super().__init__(config, output_path=output_path)
        # classification head applied to the transformed CLS-like dummy token
        self.fc = nn.Linear(self.model_input_dim, self.config["fc_dim"])
        self.relu = nn.ReLU()
        self.output = nn.Linear(self.config["fc_dim"], 40)

    def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None,
                    answer_emb=None, answer_mask=None, state_generation_mode=None):
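        """Encode the combined input and classify the transformed dummy token.

        `language_emb_mask` is assumed to be True for real tokens and False for padding;
        it is inverted below to build the encoder's `src_key_padding_mask`, while the
        video positions, the CLS-like token and any state vectors stay unmasked.
        `answer_emb`, `answer_mask` and `state_generation_mode` are not used here.
        Returns logits of shape (batch_size, 40).
        """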
        # zero-initialized placeholder token, analogous to the CLS token in BERT-style models
        dummy_word = torch.zeros(self.model_input_dim, requires_grad=True, device=self.device)
        dummy_word = torch.tile(dummy_word, (language_emb.shape[0], 1, 1))

        # combine the state vectors with the video and language embeddings
        input = self.combiner(
            video_emb,
            language_emb,
            language_emb_mask,
            v_state,
            d_state,
            dummy_word
        )

        # create the padding mask from language_emb_mask (the complete video stays unmasked)
        input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device)

        # the offset accounts for the CLS-like token and any prepended state vectors
        offset = 1
        if v_state is not None:
            offset += 1
        if d_state is not None:
            offset += 1
        if self.config['model_type'] == 'extended_model':
            # combiner option B does not take the state vectors as separate input tokens but
            # concatenates them with the embeddings, so only the CLS-like token adds to the offset
            if self.ext_config['combiner_option'] == 'OptionB':
                offset = 1
        input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask

        x = self.encoder(input, src_key_padding_mask=input_mask)

        # only the transformed dummy token is passed on to the dense classification layers
        x = self.fc(x[:, 0, :])
        x = self.relu(x)
        output = self.output(x)
        return output

    def answer_query(self, query, query_mask, vft, v_state=None, d_state=None,
                     answer=None, answer_mask=None, state_generation_mode=False):
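        """Prepare the video and language embeddings and forward them through apply_model.

        The answer is embedded when `answer` and `answer_mask` are given, although
        apply_model does not make use of it in this discriminative head.
        """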
        video_emb = self.prepare_video_emb(vft)
        lang_emb = self.prepare_lang_emb(query, query_mask)
        if answer is not None and answer_mask is not None:
            answer_emb = self.prepare_lang_emb(answer, answer_mask)
        else:
            answer_emb = None
        output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state,
                                  answer_emb, answer_mask, state_generation_mode)
        return output

    def training_step(self, train_batch, batch_idx):
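        """Training step (PyTorch Lightning style).

        `self.loss`, `self.train_acc` and `self.best_train_acc` are expected to be
        provided by the parent StateTrackerModel.
        """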
        train_batch.move_to_cuda()
        label = torch.squeeze(train_batch.answer)
        out = self.forward(train_batch)
        loss = self.loss(out, label)
        tr_acc = self.train_acc(out.softmax(dim=1), label)
        # keep track of the best training accuracy seen so far
        if tr_acc > self.best_train_acc:
            self.best_train_acc = tr_acc
        self.log("train_acc", tr_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
        print('train_loss: {} | train_acc: {}'.format(loss, tr_acc))
        return loss

    def validation_step(self, val_batch, batch_idx):
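        """Validation step: updates `self.val_acc` and logs epoch-level accuracy and loss."""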
        val_batch.move_to_cuda()
        label = torch.squeeze(val_batch.answer)
        out = self.forward(val_batch)
        loss = self.loss(out, label)
        self.val_acc(out.softmax(dim=1), label)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])
        return {'val_loss': loss, 'val_acc': self.val_acc.compute()}

    def test_step(self, test_batch, batch_idx):
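        """Test step: logs test metrics and collects per-sample predictions and metadata in `self.results`."""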
        test_batch.move_to_cuda()
        label = torch.squeeze(test_batch.answer)
        out = self.forward(test_batch)
        loss = self.loss(out, label)
        self.test_acc(out.softmax(dim=1), label)
        self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])

        # save the predictions and their metadata into the results dictionary
        out = torch.argmax(out, dim=1)
        question_as_text = []
        for i in range(test_batch.query.shape[0]):
            question_ids = test_batch.query[i, :]
            question_as_text.append(translate_from_ids_to_text(question_ids, self.tokenizer))
        self.results['question'].extend(question_as_text)
        self.results['video_name'].extend(test_batch.video_name)
        self.results['qa_id'].extend(test_batch.qa_ids)
        self.results['q_type'].extend(test_batch.q_type)
        self.results['label'].extend(label.tolist())
        self.results['output'].extend(out.tolist())
        self.results['attribute_dependency'].extend(test_batch.attribute_dependency)
        self.results['object_dependency'].extend(test_batch.object_dependency)
        self.results['temporal_dependency'].extend(test_batch.temporal_dependency)
        self.results['spatial_dependency'].extend(test_batch.spatial_dependency)
        self.results['q_complexity'].extend(test_batch.q_complexity)

    def on_test_start(self):
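        # reset the results dictionary before every test run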
        self.results = {
            'qa_id': [],
            'q_type': [],
            'label': [],
            'output': [],
            'attribute_dependency': [],
            'object_dependency': [],
            'temporal_dependency': [],
            'spatial_dependency': [],
            'q_complexity': [],
            # only needed for the input/output analysis
            'question': [],
            'video_name': []
        }

    def on_test_end(self):
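        # persist the collected test results as a pickled pandas DataFrame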
        df = pd.DataFrame.from_dict(self.results)
        df.to_pickle(self.output_path)