# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from .utils import open_img


def tokenize(text, tokenizer, return_tensor=False):
    """Convert a text string into a list of token ids (or a LongTensor)."""
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):
    """Flatten the VisDial annotation file into a list of per-round dialog items."""
    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']
    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns_visdial']
    vid_set = set()
    undisclosed_only = False
    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption'] + ' .'
        questions = [all_questions[d['question']] + ' ?' for d in dialog['dialog']]
        answers = [all_answers[d['answer']] + ' .' for d in dialog['dialog']]
        # answer_opts = [[all_answers[key] for key in d['answer_options']] for d in dialog['dialog']]
        # if 'test' in config['anno_visdial_{}'.format(split)]:
        #     gt_indices = [-1 for _ in range(len(questions))]
        # else:
        #     gt_indices = [d['gt_index'] for d in dialog['dialog']]
        vid = dialog['image_id']
        vid_set.add(vid)

        if undisclosed_only:
            # Only the last (undisclosed) round is used as a prediction target.
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))

        qalist = []
        history = []
        if undisclosed_only:
            # Pre-fill the history with all earlier rounds.
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]

        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            # answer_opt = answer_opts[n]
            # gt_index = gt_indices[n]
            history.append(question)
            # if n_history == 0:
            #     item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n + 1, 'answer_opts': answer_opt, 'gt_index': gt_index}
            # else:
            #     item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n + 1, 'answer_opts': answer_opt, 'gt_index': gt_index}
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer,
                        'caption': caption, 'round': n + 1}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer,
                        'caption': caption, 'round': n + 1}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            # Keep at most the last n_history question/answer pairs as history.
            history = qalist[max(-len(qalist), -n_history * 2):]
    return dialog_list


class VisDial(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split
                 # tokenizer, features=None, drop_rate=0.0, train=True
                 ):
        self.config = config
        self.medium = medium
        self.split = split
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' \
            else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        self.dialogs = get_dataset(config, split)
        if split == 'test':
            self.dialogs = self.dialogs[config['start_idx_gen']:config['end_idx_gen']]
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_img(self, vid_id):
        file_pth = os.path.join(self.root_vis, f'{vid_id}.jpg')
        vis = open_img(file_pth)
        vis = self.vis_processor(vis).unsqueeze(0)
        return vis

    def __getitem__(self, index):
        dialog = self.dialogs[index]
        vid_id = dialog['vid']
        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']
        d_round = dialog['round']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        # if self.split == 'test':
        #     answer_opts = dialog['answer_opts']
        #     answer_opts = [self.text_processor(a) for a in answer_opts]
        #     gt_index = dialog['gt_index']
        #     dialog_round = dialog['round']
        #     dense_key = str(vid_id) + '_' + str(dialog_round)
        #     gt_relevance = self.dense_annos.get(dense_key, -1)
        #     # eval_data = (answer_opts, gt_index, gt_relevance)

        # Choose the special tokens used to stitch the caption and history together.
        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = ''
                eos_tok = ''
            else:
                cls_tok = ''
                sep_tok = ''
                bos_tok = ''
                eos_tok = ''
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # load the image
        vis = self.load_img(vid_id)

        # if self.split == 'test':
        #     return vis, caption, history, answer, vid_id, answer_opts, gt_relevance, gt_index
        # else:
        return vis, caption, history, answer, vid_id, d_round


def load_visdial_dataset(config, vis_processor, text_processor, split):
    dataset = VisDial(config, 'visdial', vis_processor, text_processor, split)
    return dataset
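

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The paths, the stand-in processors,
# and the _AttrDict helper below are assumptions made for this example; the
# real project supplies its own config object (the code above accesses it both
# as config['key'] and config.attr) and its own vision/text processors.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _AttrDict(dict):
        """Tiny stand-in config supporting both item and attribute access."""
        def __getattr__(self, name):
            return self[name]

    def _demo_vis_processor(img):
        # The real vision processor returns a tensor; ToTensor() is a stand-in.
        return transforms.ToTensor()(img)

    def _demo_text_processor(text, remove_period=False):
        text = text.strip()
        return text.rstrip(' .') if remove_period else text

    demo_config = _AttrDict({
        'anno_visdial_train': 'data/visdial_1.0_train.json',  # assumed path
        'root_raw_vis_visdial_train': 'data/images/train',    # assumed path
        'num_hist_turns_visdial': 3,
        'batch_size_visdial': 8,
        'num_samples_visdial': 16,
        'embed_from_llm': False,
    })

    dataset = load_visdial_dataset(demo_config, _demo_vis_processor,
                                   _demo_text_processor, 'train')
    vis, caption, history, answer, vid_id, d_round = dataset[0]
    print(vid_id, d_round, caption, history, answer, vis.shape)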