Make code public
commit 8e03ef1c38
49 changed files with 545,354 additions and 0 deletions
custom_datasets/README.md (Normal file, 20 additions)
@@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
   ```shell
   python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
   python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
   ```
5. Embed the crops (a sanity-check sketch for the resulting files follows these steps)
   ```shell
   python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
   python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
   ```
6. Preprocess and log the data
   ```shell
   python dataset.py --split train
   python dataset.py --split val
   ```
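A quick sanity check for step 5 (a sketch, not part of the pipeline; it assumes the embeddings were written to `../features/sam` as above, and that each object vector is 512-d, i.e. SAM ViT-H's 256-channel image embedding concatenated with the flipped-view embedding as in `segment.py`):

```python
import os
import numpy as np

embed_root = '../features/sam'  # same path as used in step 5
for fname in sorted(os.listdir(embed_root))[:5]:
    feats = np.load(os.path.join(embed_root, fname))
    # one row per detected object (at most 36), 512 dims per row
    print(fname, feats.shape)
```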
custom_datasets/__init__.py (Normal file, 0 additions)
custom_datasets/avsd.py (Normal file, 401 additions)
@@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # I3D_flow
S1_TOK = '<s1>'  # I3D_rgb
S2_TOK = '<s2>'  # SAM objects
S3_TOK = '<s3>'  # audio
S4_TOK = '<s4>'  # history
S5_TOK = '<s5>'  # question


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class AVSDDataset(Dataset):
    def __init__(self, config, split):
        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.cap_sum = config['cap_sum']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
        self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
        self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))

        if self.config['overfit'] > 0:
            self.paths = self.paths[:self.config['overfit_size']]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        pth = self.paths[index]
        with open(pth, 'rb') as f:
            item = pickle.load(f)

        question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')

        input_ids = item['input_ids']
        history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]

        history_interval = [0, history_end.item()]  # The last token is the question state token (not part of the history)
        question_interval = [history_end.item(), input_ids.size(0)]

        lm_labels = item['lm_labels']
        i3d_rgb = item['i3d_rgb']
        i3d_flow = item['i3d_flow']
        sam = item['sam']
        vgg = item['vgg']
        vid = item['vid']

        return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            history_interval_list.append(i[2])
            question_interval_list.append(i[3])
            i3d_rgb_list.append(i[4])
            i3d_flow_list.append(i[5])
            sam_list.append(i[6])
            vggish_list.append(i[7])
            vid_ids_list.append(i[8])

        history_intervals = np.array(history_interval_list)
        question_intervals = np.array(question_interval_list)

        min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
        min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
        min_len_sam = min([feat.shape[0] for feat in sam_list])
        min_len_vggish = min([feat.shape[0] for feat in vggish_list])

        min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])

        # Sample equally-distant features from the visual features for each sample within the batch
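        # Illustration with hypothetical numbers: a clip with 10 feature steps and
        # min_length = 4 keeps the steps at indices
        # np.round(np.linspace(0, 9, 4)) -> [0, 3, 6, 9]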
        for i in range(len(i3d_rgb_list)):
            sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
        i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()

        for i in range(len(i3d_flow_list)):
            sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
        i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()

        for i in range(len(sam_list)):
            sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
            sam_list[i] = sam_list[i][sample_idx_sam, :]
        sam = torch.from_numpy(np.array(sam_list)).float()

        for i in range(len(vggish_list)):
            sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
            vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
        vggish = torch.from_numpy(np.array(vggish_list)).float()

        pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])

        # All the visual features will not be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1  # NOTE *4: 4 modalities | +4: the state tokens
        # Now we create a dummy input for the video tokens (sole purpose is to reserve the spot of the separators)
        dummy = torch.ones((len(batch), min_length)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
             torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
             torch.ones((len(batch), 1)) * sam_sep, dummy,
             torch.ones((len(batch), 1)) * audio_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        i3d_rgb_interval = [0, min_length + 1]  # the last token is not part of this modality
        i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
        sam_interval = [2 * min_length + 2, 3 * min_length + 3]
        audio_interval = [3 * min_length + 3, 4 * min_length + 4]

        vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
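        # Illustration with a hypothetical min_length of 10; the encoder input is laid out as:
        #   [ 0, 11) : <s1> + 10 I3D-RGB slots
        #   [11, 22) : <s0> + 10 I3D-flow slots
        #   [22, 33) : <s2> + 10 SAM object slots
        #   [33, 44) : <s3> + 10 VGGish slots
        #   [44, ..) : text tokens, hence the 4 * min_length + 4 shift below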

        # Adapt the question and history intervals -- shifted to the right by the visual input length
        history_intervals += 4 * min_length + 4
        question_intervals += 4 * min_length + 4
        history_intervals = history_intervals.tolist()
        question_intervals = question_intervals.tolist()

        history_state_vector_idx = [x[0] + 1 for x in history_intervals]  # +1 because the history starts with <s><s4> .....
        question_state_vector_idx = [x[0] for x in question_intervals]  # the question interval already starts at the <s5> state token

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'i3d_rgb': i3d_rgb,
            'i3d_flow': i3d_flow,
            'sam': sam,
            'vggish': vggish,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'i3d_rgb_interval': i3d_rgb_interval,
            'i3d_flow_interval': i3d_flow_interval,
            'sam_interval': sam_interval,
            'audio_interval': audio_interval,
            'history_intervals': history_intervals,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'history_state_vector_idx': history_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch

def get_dataset(config, split, tokenizer):
    if split != 'test':
        dialog_pth = config[f'avsd_{split}']
    else:
        dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
    n_history = config['n_history']
    dialog_data = json.load(open(dialog_pth, 'r'))
    dialog_list = []
    vid_set = set()
    undisclosed_only = split == 'test'
    pbar = tqdm(dialog_data['dialogs'])

    pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
    for dialog in pbar:
        if config['dstc'] != 10:
            caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
        else:
            caption = [tokenize('no', tokenizer)]

        questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
        answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    all_features = {}
    fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']

    dataname = '<FeaType>/<ImageID>.npy'
    for ftype in fea_types:
        if undisclosed_only:
            basename = dataname.replace('<FeaType>', ftype + '_testset')
        else:
            basename = dataname.replace('<FeaType>', ftype)
        features = {}
        for vid in vid_set:
            filename = basename.replace('<ImageID>', vid)
            filepath = config['avsd_feature_path'] + filename
            features[vid] = filepath
        all_features[ftype] = features
    return dialog_list, all_features


def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
    """ Build a sequence of input from 3 segments: caption (caption + summary), history and last reply """

    bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
    sep = eos

    instance = {}
    instance["lm_labels"] = reply + [eos]
    caption = list(chain(*caption))

    # Add state tokens if applicable
    if add_state_tokens:
        caption.insert(0, hist_state)
        history = deepcopy(history_orig)
        history[-1].insert(0, ques_state)
    else:
        history = history_orig

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE It is important not to include the reply in the input of the encoder --> the decoder will just
        # learn to copy it --> low train/val loss but no learning is happening
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]

    instance["input_ids"] = list(chain(*sequence))
    return instance
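# Illustration with a hypothetical history [q1, a1, q2] (one previous round plus the
# current question q2) and add_state_tokens=True: the encoder input becomes
#   <s> <s4> caption summary </s> </s> q1 </s> a1 </s> <s5> q2 </s>
# and lm_labels is the tokenized answer followed by </s>.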


def parse_args():
    parser = ArgumentParser(description='debug dataloader')
    parser.add_argument(
        '--split',
        type=str,
        default='train',
        help='train or val')

    parser.add_argument(
        '--model',
        type=str,
        default='mixer',
        help='model name to train or test')

    parser.add_argument(
        '--log_dataset',
        action='store_true',
        default=False,
        help='Whether or not to log the processed data')

    parser.add_argument(
        '--add_state_tokens',
        action='store_true',
        default=True,
        help='Whether or not to add state tokens')

    parser.add_argument(
        '--log_dir',
        type=str,
        default='processed/avsd',
        help='Output directory')

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()
    split = args.split

    config = pyhocon.ConfigFactory.parse_file(
        'config/mst_mixer.conf')[args.model]
    config['expand_rnd'] = False
    config['debugging'] = False
    config['overfit'] = False
    args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']))
    if args.log_dataset:
        log_dir = os.path.join(args.log_dir, split)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
        tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        dialogs, features = get_dataset(config, split, tokenizer)
        pbar = tqdm(dialogs)
        pbar.set_description('[{}] Logging processed data'.format(split))
        counter = 0
        for dialog in pbar:
            vid = dialog['vid']
            his = dialog['history']
            cap = dialog['caption']
            ans = dialog['answer']

            if np.random.rand() < config['caption_drop_rate']:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
            else:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)

            input_ids = torch.Tensor(instance["input_ids"]).long()
            lm_labels = torch.Tensor(instance["lm_labels"]).long()

            vgg = np.load(features["vggish"][vid])
            i3d_flow = np.load(features["i3d_flow"][vid])
            i3d_rgb = np.load(features["i3d_rgb"][vid])
            sam = np.load(features["sam"][vid])

            item = {
                'input_ids': input_ids,
                'lm_labels': lm_labels,
                'i3d_rgb': i3d_rgb,
                'i3d_flow': i3d_flow,
                'sam': sam,
                'vgg': vgg,
                'vid': vid
            }
            counter += 1
            pth = os.path.join(log_dir, str(counter) + '.pkl')
            with open(pth, 'wb') as f:
                pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        avsd_dataset = AVSDDataset(config, 'val')
        avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)

        for i, data in enumerate(avsd_dataloader):
            print('{}/{}'.format(i, len(avsd_dataloader)))

    print('[INFO] Done...')
custom_datasets/nextqa.py (Normal file, 211 additions)
@@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # frame
S1_TOK = '<s1>'  # mot
S2_TOK = '<s2>'  # question


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)

    return annos


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class NextQADataset(Dataset):
    def __init__(self, config, split):
        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))

        sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
        self.sample_list = load_file(sample_list_file)

        vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
        print('Load {}...'.format(vid_feat_file))
        self.frame_feats = {}
        self.mot_feats = {}
        with h5py.File(vid_feat_file, 'r') as fp:
            vids = fp['ids']
            feats = fp['feat']
            for vid, feat in zip(vids, feats):
                self.frame_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
                self.mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

        if self.config['overfit_size'] > 0:
            self.sample_list = self.sample_list[:self.config['overfit_size']]

    def __len__(self):
        return len(self.sample_list)

    def get_video_feature(self, video_name):
        """Return the appearance and motion features of the given video."""

        app_feat = self.frame_feats[video_name]
        app_feat = torch.from_numpy(app_feat).type(torch.float32)

        mot_feat = self.mot_feats[video_name]
        mot_feat = torch.from_numpy(mot_feat).type(torch.float32)

        return app_feat, mot_feat

    def __getitem__(self, idx):
        cur_sample = self.sample_list.loc[idx]
        video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
            str(cur_sample['answer']), str(cur_sample['qid'])

        input_ids = tokenize(ques, self.tokenizer)
        lm_labels = tokenize(ans, self.tokenizer)

        app_feat, mot_feat = self.get_video_feature(video_name)

        bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

        # Add state tokens
        input_ids.insert(0, ques_state)
        lm_labels.append(eos)
        question_interval = [0, len(input_ids)]

        input_ids = torch.Tensor(input_ids).long()
        lm_labels = torch.Tensor(lm_labels).long()

        return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            app_feat_list.append(i[2])
            mot_feat_list.append(i[3])
            question_interval_list.append(i[4])
            vid_ids_list.append(i[5])

        app_feats = torch.stack(app_feat_list, dim=0).float()
        mot_feats = torch.stack(mot_feat_list, dim=0).float()

        question_intervals = np.array(question_interval_list)

        pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<place_holder>'])

        # All the visual features will not be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), 16*2 + 2)) == 1  # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
        # Now we create a dummy input for the video tokens (sole purpose is to reserve the spot of the separators)
        dummy = torch.ones((len(batch), 16)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * app_sep, dummy,
             torch.ones((len(batch), 1)) * mot_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        app_interval = [0, 16 + 1]  # the last token is not part of this modality
        mot_interval = [16 + 1, 2 * 16 + 2]
        vis_state_vector_idx = [app_interval[0], mot_interval[0]]
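        # Illustration of the resulting layout (16 features per modality):
        #   [ 0, 17) : <s0> + 16 appearance slots
        #   [17, 34) : <s1> + 16 motion slots
        #   [34, ..) : question tokens, hence the 2 * 16 + 2 shift below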

        # Adapt the question interval -- shifted to the right by the visual input length
        question_intervals += 2 * 16 + 2
        question_intervals = question_intervals.tolist()

        question_state_vector_idx = [x[0] for x in question_intervals]

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'app_feats': app_feats,
            'mot_feats': mot_feats,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'app_interval': app_interval,
            'mot_interval': mot_interval,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch


def get_dataset(config, split):
    bart_max_input_len = config['bart_max_input_len']
    bart_size = config['bart_size']

    sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
    sample_list = load_file(sample_list_file)

    vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
    print('Load {}...'.format(vid_feat_file))
    app_feats = {}
    mot_feats = {}
    with h5py.File(vid_feat_file, 'r') as fp:
        vids = fp['ids']
        feats = fp['feat']
        for vid, feat in zip(vids, feats):
            app_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
            mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

    return sample_list, app_feats, mot_feats
custom_datasets/segment.py (Normal file, 179 additions)
@@ -0,0 +1,179 @@
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
from argparse import ArgumentParser
import pickle
import cv2
import os
import torch
import numpy as np


def parse_args():
    parser = ArgumentParser()

    parser.add_argument(
        '--sam_ckpt',
        type=str,
        help='SAM checkpoint to be used'
    )

    parser.add_argument(
        '--avsd_root',
        type=str,
        help='Directory where the individual AVSD frames are located'
    )

    parser.add_argument(
        '--crop_root',
        type=str,
        help='Directory where the individual crops (objects) will be saved'
    )

    parser.add_argument(
        '--embed_root',
        type=str,
        help='Directory where the individual embeddings will be saved'
    )

    parser.add_argument(
        '--mode',
        type=str,
        choices=['segment', 'embed'],
        help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
    )

    parser.add_argument(
        '--start',
        type=int,
        default=0,
        help='Start index of the partition'
    )

    parser.add_argument(
        '--end',
        type=int,
        default=1968,
        help='End index of the partition'
    )

    args = parser.parse_args()
    return args


def partition_ids(avsd_ids, start, end):
    avsd_ids.sort()
    assert start < end
    assert start >= 0 and end <= len(avsd_ids)
    avsd_ids_partition = avsd_ids[start:end]
    return avsd_ids_partition


def get_middle_frames(avsd_ids_partition, avsd_root):
    pbar = tqdm(avsd_ids_partition)
    pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
    path_list = []
    for avsd_id in pbar:
        frames = os.listdir(os.path.join(avsd_root, avsd_id))
        if 'test' in avsd_root:
            frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
        else:
            frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
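        # The sort keys above assume frame files whose numeric index follows the last
        # '-' (train/val) or '_' (test), e.g. a hypothetical XXXXX-000123.jpg; frames
        # are ordered by that index before picking the middle one below.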
        middle_frame = frames[int(len(frames)/2)]
        middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
        path_list.append(middle_frame)
    return path_list


def segment_images(sam, path_list, crop_root):
    mask_generator = SamAutomaticMaskGenerator(sam)
    pbar = tqdm(path_list)
    pbar.set_description('Detecting Objects')
    for pth in pbar:
        vid_id = pth.split('/')[-2]
        crop_dir = os.path.join(crop_root, vid_id)
        if not os.path.isdir(crop_dir):
            os.makedirs(crop_dir)

        image = cv2.imread(pth)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = mask_generator.generate(image)
        masks.sort(key=lambda e: e['stability_score'], reverse=True)
        if len(masks) > 36:
            masks = masks[:36]
        for i, mask in enumerate(masks):
            crop = image[
                int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
                int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
                :
            ]
            crop_flipped = cv2.flip(crop, 1)  # Horizontal flip
            cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
            cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)

    print('[INFO] Done...')


def embed_objects(sam, crop_ids, crop_root, embed_root):
    predictor = SamPredictor(sam)
    pbar = tqdm(crop_ids)
    pbar.set_description('Embedding Objects')
    for vid_id in pbar:
        embeds = []
        crop_dir = os.path.join(crop_root, vid_id)
        crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
        crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
        crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
        for cp in crop_paths:
            crop = cv2.imread(cp)
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            predictor.set_image(crop)
            embed_crop = predictor.get_image_embedding()
            embed_crop = embed_crop.mean(-1).mean(-1)

            crop_flipped = cv2.flip(crop, 1)
            predictor.set_image(crop_flipped)
            embed_crop_flipped = predictor.get_image_embedding()
            embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)

            embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
            # embed = embed.copy().cpu()
            embeds.append(embed)

        embeds = torch.cat(embeds, 0).cpu().numpy()
        np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)

    print('[INFO] Done...')


def segment(args, sam):
    avsd_ids = os.listdir(args.avsd_root)
    avsd_ids.sort()
    avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
    path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)

    if not os.path.isdir(args.crop_root):
        os.makedirs(args.crop_root)
    segment_images(sam, path_list, args.crop_root)


def embed(args, sam):
    crop_ids = os.listdir(args.crop_root)
    crop_ids.sort()
    crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
    if not os.path.isdir(args.embed_root):
        os.makedirs(args.embed_root)
    embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)


if __name__ == '__main__':
    args = parse_args()
    sam = sam_model_registry['vit_h'](
        checkpoint=args.sam_ckpt)
    device = 'cuda'
    sam.to(device=device)

    assert args.mode in ['segment', 'embed']
    if args.mode == 'segment':
        segment(args, sam)
    else:
        embed(args, sam)