# coding: utf-8
# author: noctli

import json
import logging
import os
import pickle
from itertools import chain

import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from .utils import type_transform_helper
from .video_utils import read_frames_decord
def tokenize(text, tokenizer, return_tensor=False):
    """Tokenize `text` and map it to vocabulary ids, optionally as a LongTensor."""
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text
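# Example (a sketch; the tokenizer class and checkpoint are illustrative assumptions):
#   from transformers import BartTokenizer
#   tok = BartTokenizer.from_pretrained('facebook/bart-base')
#   ids = tokenize('is there a person in the video ?', tok, return_tensor=True)
#   # ids -> 1-D torch.LongTensor of vocabulary ids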


def get_dataset(config, split):
    if split != 'test':
        dialog_pth = config[f'anno_avsd_{split}']
    else:
        dialog_pth = config['anno_avsd_test_dstc_{}'.format(config['dstc'])]
    n_history = config['num_hist_turns_avsd']
    undisclosed_only = split == 'test'
    with open(dialog_pth, 'r') as f:
        dialog_data = json.load(f)
    dialog_list = []
    vid_set = set()
    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading AVSD - {}'.format(split))
    for dialog in pbar:
        # if config['dstc'] != 10:
        caption = dialog['caption']
        summary = dialog['summary']
        # else:
        #     caption = 'no'
        #     summary = 'no'

        questions = [d['question'] for d in dialog['dialog']]
        answers = [d['answer'] for d in dialog['dialog']]
        vid = dialog['image_id']
        vid_set.add(vid)
        if undisclosed_only:
            # At test time, only the last turn (whose answer is undisclosed) is generated.
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            # Pre-fill the history with all fully disclosed turns.
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'summary': summary}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'summary': summary}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            # Keep only the last n_history question/answer pairs.
            history = qalist[max(-len(qalist), -n_history * 2):]

    return dialog_list
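# Shape sketch of one returned item (field values are illustrative):
#   {'vid': 'ABC1D',
#    'history': ['what is he doing ?', 'he is cooking', 'does he speak ?'],
#    'answer': 'no , he is silent',
#    'caption': '...', 'summary': '...'}
# With n_history turns kept, 'history' holds at most 2 * n_history + 1 strings
# (the retained question/answer pairs plus the current question).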


def build_input_from_segments(caption, history, reply, tokenizer, drop_caption=False):
    """Build an input sequence from 3 segments: caption (caption + summary), history, and the last reply."""
    bos, eos = tokenizer.convert_tokens_to_ids(['<s>', '</s>'])
    sep = eos

    instance = {}
    instance['lm_labels'] = reply + [eos]
    caption = list(chain(*caption))

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE: it is important not to include the reply in the input of the encoder --> the decoder
        # would just learn to copy it --> low train/val loss, but no learning happens.
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[sep] + s for s in history] + [[eos]]

    instance['input_ids'] = list(chain(*sequence))
    return instance
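# Layout sketch of the built instance (sep == eos; tokens shown symbolically):
#   input_ids: <s> caption </s>  </s> turn_1  </s> turn_2 ...  </s>
#   lm_labels: reply </s>
# The reply appears only in lm_labels, never in input_ids (see NOTE above).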


class AVSDDataSet(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split
                 # tokenizer, features=None, drop_rate=0.0, train=True
                 ):
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        self.dialogs = get_dataset(config, split)

        if split == 'test':
            # Generate only on the dialog slice assigned to this run.
            self.dialogs = self.dialogs[config['start_idx_gen']:config['end_idx_gen']]

        # Optionally cap the number of samples (e.g. for debugging); <= 0 keeps all.
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_vid(self, vid_id):
        vid_dir_path = os.path.join(self.root_vis, vid_id + '.mp4')

        frames, _, _ = read_frames_decord(vid_dir_path, self.config['num_frames'])
        frames = [self.vis_processor(f).unsqueeze(0) for f in frames]

        vis = torch.cat(frames, dim=0)
        return vis
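    # Shape sketch (assumes the vis_processor returns one (C, H, W) tensor per frame):
    #   vis = self.load_vid('ABC1D')   # reads <root_vis>/ABC1D.mp4
    #   vis.shape -> (num_frames, C, H, W)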

    def load_vid_old(self, vid_id):
        # Legacy loader: samples num_frames near-uniformly from pre-extracted frame images.
        vid_dir_path = os.path.join(self.root_vis, vid_id)
        frame_paths = [os.path.join(vid_dir_path, f) for f in os.listdir(vid_dir_path)]
        frame_paths.sort()
        num_avail_frames = len(frame_paths)
        # Guard against a zero step when fewer frames are available than requested.
        delta = max(1, num_avail_frames // (self.config['num_frames'] - 1))
        ran = list(range(0, num_avail_frames, delta))
        if len(ran) < self.config['num_frames']:
            # Pad with the last frame index if the video is too short.
            ran.extend([num_avail_frames - 1 for _ in range(self.config['num_frames'] - len(ran))])
        if len(ran) > self.config['num_frames']:
            ran = ran[:self.config['num_frames']]
        assert len(ran) == self.config['num_frames'], f"vid {vid_id} - loaded {len(ran)}/{len(frame_paths)} frames"
        frame_paths = [frame_paths[i] for i in ran]
        vis = [Image.open(p).convert('RGB') for p in frame_paths]
        vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)
        # NOTE: self.trans is not set in __init__; this legacy path expects it to be
        # attached externally before use.
        vis = self.trans(vis)
        return vis
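    # Worked example of the index sampling above (illustrative numbers):
    #   num_avail_frames = 100, num_frames = 4
    #   delta = max(1, 100 // 3) = 33 -> ran = [0, 33, 66, 99]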

    def __getitem__(self, index):
        dialog = self.dialogs[index]
        vid_id = dialog['vid']

        caption = dialog['caption']
        summary = dialog['summary']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        summary = self.text_processor(summary)
        if self.config['dstc'] != 10:
            caption = caption + ' ' + summary

        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        # Pick the special tokens matching the language backbone's vocabulary
        # (bos_tok/eos_tok are defined for completeness but unused below).
        if self.config['embed_from_llm']:
            if self.config['llm_family'] in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # Load the video frames.
        vis = self.load_vid(vid_id)

        return vis, caption, history, answer, vid_id
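    # Worked example of the text assembly above, non-LLM branch ('[CLS]'/'[SEP]'),
    # with illustrative strings:
    #   caption -> '[CLS]a man is cooking . he stirs a pot .[SEP]'
    #   history -> 'what is he doing ?[SEP]he is cooking[SEP]'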


def load_avsd_dataset(config, vis_processor, text_processor, split):
    # data_file = config['anno_avsd_{}'.format(split)]
    # dataset_list = get_dataset(config, split, tokenizer_enc_dec)
    dataset = AVSDDataSet(config, 'avsd', vis_processor, text_processor, split)
    return dataset
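# Usage sketch (illustrative; `config` and the processors come from the caller's setup):
#   dataset = load_avsd_dataset(config, vis_processor, text_processor, 'train')
#   vis, caption, history, answer, vid_id = dataset[0]
#   loader = torch.utils.data.DataLoader(dataset, batch_size=dataset.batch_size, shuffle=True)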