neuro-symbolic-visual-dialog/prog_generator/clevrDialog_dataset.py
2022-08-10 16:49:55 +02:00

94 lines
3.7 KiB
Python

"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
import h5py
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
def invertDict(_dict):
return {v: k for k, v in _dict.items()}
class ClevrDialogDataset(Dataset):
def __init__(self, dataPath, vocabPath, split, indStart=0, indEnd=-1):
super(ClevrDialogDataset, self).__init__()
self.data = h5py.File(dataPath, "r")
with open(vocabPath, "r") as f:
self.vocab = json.load(f)
self.vocab["idx_text_to_token"] = invertDict(self.vocab["text_token_to_idx"])
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
self.lenVocabText = len(self.vocab["text_token_to_idx"])
self.lenVocabProg = len(self.vocab["prog_token_to_idx"])
self.split = split
self.indStart = indStart
self.indEnd = indEnd
self.maxSamples = indEnd - indStart
self.maxLenProg = 6
def __len__(self):
raise NotImplementedError
def __getitem__(self, index):
raise NotImplementedError
class ClevrDialogCaptionDataset(ClevrDialogDataset):
def __init__(self, dataPath, vocabPath, split, name, indStart=0, indEnd=-1):
super(ClevrDialogCaptionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
self.captions = torch.LongTensor(np.asarray(self.data["captions"], dtype=np.int64)[indStart: indEnd])
self.captionsPrgs = torch.LongTensor(np.asarray(self.data["captionProgs"], dtype=np.int64)[indStart: indEnd])
self.name = name
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
assert idx < len(self)
caption = self.captions[idx][:16]
captionPrg = self.captionsPrgs[idx]
return caption, captionPrg
class ClevrDialogQuestionDataset(ClevrDialogDataset):
def __init__(self, dataPath, vocabPath, split, name, train=True, indStart=0, indEnd=-1):
super(ClevrDialogQuestionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
self.questions = torch.LongTensor(np.asarray(self.data["questions"], dtype=np.int64)[indStart: indEnd])
self.quesProgs = torch.LongTensor(np.asarray(self.data["questionProgs"], dtype=np.int64)[indStart: indEnd])
self.questionRounds = torch.LongTensor(np.asarray(self.data["questionRounds"], dtype=np.int64)[indStart: indEnd])
self.questionImgIdx = torch.LongTensor(np.asarray(self.data["questionImgIdx"], dtype=np.int64)[indStart: indEnd])
self.histories = torch.LongTensor(np.asarray(self.data["histories"], dtype=np.int64)[indStart: indEnd])
self.historiesProgs = torch.LongTensor(np.asarray(self.data["historiesProg"], dtype=np.int64)[indStart: indEnd])
self.answers = torch.LongTensor(np.asarray(self.data["answers"], dtype=np.int64)[indStart: indEnd])
self.name = name
self.train = train
def __len__(self):
return len(self.questions)
def __getitem__(self, idx):
assert idx < len(self)
question = self.questions[idx]
questionPrg = self.quesProgs[idx]
questionImgIdx = self.questionImgIdx[idx]
questionRound = self.questionRounds[idx]
history = self.histories[idx]
historiesProg = self.historiesProgs[idx]
answer = self.answers[idx]
if self.train:
return question, history, questionPrg, questionRound, answer
else:
return question, questionPrg, questionImgIdx, questionRound, history, historiesProg, answer