94 lines
3.7 KiB
Python
94 lines
3.7 KiB
Python
"""
|
|
author: Adnen Abdessaied
|
|
maintainer: "Adnen Abdessaied"
|
|
website: adnenabdessaied.de
|
|
version: 1.0.1
|
|
"""
|
|
|
|
import h5py
|
|
import json
|
|
import os
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
def invertDict(_dict):
|
|
return {v: k for k, v in _dict.items()}
|
|
|
|
|
|
class ClevrDialogDataset(Dataset):
|
|
def __init__(self, dataPath, vocabPath, split, indStart=0, indEnd=-1):
|
|
super(ClevrDialogDataset, self).__init__()
|
|
self.data = h5py.File(dataPath, "r")
|
|
with open(vocabPath, "r") as f:
|
|
self.vocab = json.load(f)
|
|
self.vocab["idx_text_to_token"] = invertDict(self.vocab["text_token_to_idx"])
|
|
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
|
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
|
self.lenVocabText = len(self.vocab["text_token_to_idx"])
|
|
self.lenVocabProg = len(self.vocab["prog_token_to_idx"])
|
|
|
|
self.split = split
|
|
self.indStart = indStart
|
|
self.indEnd = indEnd
|
|
self.maxSamples = indEnd - indStart
|
|
self.maxLenProg = 6
|
|
|
|
def __len__(self):
|
|
raise NotImplementedError
|
|
|
|
def __getitem__(self, index):
|
|
raise NotImplementedError
|
|
|
|
|
|
class ClevrDialogCaptionDataset(ClevrDialogDataset):
|
|
def __init__(self, dataPath, vocabPath, split, name, indStart=0, indEnd=-1):
|
|
super(ClevrDialogCaptionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
|
self.captions = torch.LongTensor(np.asarray(self.data["captions"], dtype=np.int64)[indStart: indEnd])
|
|
self.captionsPrgs = torch.LongTensor(np.asarray(self.data["captionProgs"], dtype=np.int64)[indStart: indEnd])
|
|
self.name = name
|
|
|
|
def __len__(self):
|
|
return len(self.captions)
|
|
|
|
def __getitem__(self, idx):
|
|
assert idx < len(self)
|
|
caption = self.captions[idx][:16]
|
|
captionPrg = self.captionsPrgs[idx]
|
|
return caption, captionPrg
|
|
|
|
|
|
class ClevrDialogQuestionDataset(ClevrDialogDataset):
|
|
def __init__(self, dataPath, vocabPath, split, name, train=True, indStart=0, indEnd=-1):
|
|
super(ClevrDialogQuestionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
|
self.questions = torch.LongTensor(np.asarray(self.data["questions"], dtype=np.int64)[indStart: indEnd])
|
|
self.quesProgs = torch.LongTensor(np.asarray(self.data["questionProgs"], dtype=np.int64)[indStart: indEnd])
|
|
self.questionRounds = torch.LongTensor(np.asarray(self.data["questionRounds"], dtype=np.int64)[indStart: indEnd])
|
|
self.questionImgIdx = torch.LongTensor(np.asarray(self.data["questionImgIdx"], dtype=np.int64)[indStart: indEnd])
|
|
self.histories = torch.LongTensor(np.asarray(self.data["histories"], dtype=np.int64)[indStart: indEnd])
|
|
self.historiesProgs = torch.LongTensor(np.asarray(self.data["historiesProg"], dtype=np.int64)[indStart: indEnd])
|
|
|
|
self.answers = torch.LongTensor(np.asarray(self.data["answers"], dtype=np.int64)[indStart: indEnd])
|
|
self.name = name
|
|
self.train = train
|
|
|
|
def __len__(self):
|
|
return len(self.questions)
|
|
|
|
def __getitem__(self, idx):
|
|
assert idx < len(self)
|
|
question = self.questions[idx]
|
|
questionPrg = self.quesProgs[idx]
|
|
questionImgIdx = self.questionImgIdx[idx]
|
|
questionRound = self.questionRounds[idx]
|
|
|
|
history = self.histories[idx]
|
|
historiesProg = self.historiesProgs[idx]
|
|
|
|
answer = self.answers[idx]
|
|
if self.train:
|
|
return question, history, questionPrg, questionRound, answer
|
|
else:
|
|
return question, questionPrg, questionImgIdx, questionRound, history, historiesProg, answer
|