Make code public

Adnen Abdessaied 2024-07-08 11:41:28 +02:00
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions

custom_datasets/README.md (new file, +20 lines)

@@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
```
5. Embed the crops
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
```
6. Preprocess and log the data (a quick way to sanity-check the logged output is sketched after this list)
```shell
python dataset.py --split train
python dataset.py --split val
```

custom_datasets/avsd.py (new file, +401 lines)

@@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # I3D_flow
S1_TOK = '<s1>' # I3D_rgb
S2_TOK = '<s2>' # sam obj
S3_TOK = '<s3>' # audio
S4_TOK = '<s4>' # history
S5_TOK = '<s5>' # question
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class AVSDDataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.cap_sum = config['cap_sum']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))
if self.config['overfit'] > 0:
self.paths = self.paths[:self.config['overfit_size']]
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
pth = self.paths[index]
with open(pth, 'rb') as f:
item = pickle.load(f)
question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')
input_ids = item['input_ids']
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_interval = [0, history_end.item()] # The last token is the question state token (not part of the history)
question_interval = [history_end.item(), input_ids.size(0)]
lm_labels = item['lm_labels']
i3d_rgb = item['i3d_rgb']
i3d_flow = item['i3d_flow']
sam = item['sam']
vgg = item['vgg']
vid = item['vid']
return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
history_interval_list.append(i[2])
question_interval_list.append(i[3])
i3d_rgb_list.append(i[4])
i3d_flow_list.append(i[5])
sam_list.append(i[6])
vggish_list.append(i[7])
vid_ids_list.append(i[8])
history_intervals = np.array(history_interval_list)
question_intervals = np.array(question_interval_list)
min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
min_len_sam = min([feat.shape[0] for feat in sam_list])
min_len_vggish = min([feat.shape[0] for feat in vggish_list])
min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])
# Sample equally-distant features from the visual features for each sample within the batch
for i in range(len(i3d_rgb_list)):
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()
for i in range(len(i3d_flow_list)):
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()
for i in range(len(sam_list)):
sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
sam_list[i] = sam_list[i][sample_idx_sam, :]
sam = torch.from_numpy(np.array(sam_list)).float()
for i in range(len(vggish_list)):
sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
vggish = torch.from_numpy(np.array(vggish_list)).float()
pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1 # NOTE *4: 4 modalities | +4: the state tokens
# Create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
torch.ones((len(batch), 1)) * sam_sep, dummy,
torch.ones((len(batch), 1)) * audio_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
i3d_rgb_interval = [0, min_length + 1] # the last token is not part of this modality
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
history_intervals += 4 * min_length + 4
question_intervals += 4 * min_length + 4
history_intervals = history_intervals.tolist()
question_intervals = question_intervals.tolist()
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> ...
question_state_vector_idx = [x[0] for x in question_intervals] # no offset needed: the question interval starts directly at <s5>
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vggish': vggish,
'lm_labels': lm_labels,
'input_mask': input_mask,
'i3d_rgb_interval': i3d_rgb_interval,
'i3d_flow_interval': i3d_flow_interval,
'sam_interval': sam_interval,
'audio_interval': audio_interval,
'history_intervals': history_intervals,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'history_state_vector_idx': history_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
def get_dataset(config, split, tokenizer):
if split != 'test':
dialog_pth = config[f'avsd_{split}']
else:
dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
n_history = config['n_history']
dialog_data = json.load(open(dialog_pth, 'r'))
dialog_list = []
vid_set = set()
undisclosed_only = split == 'test'
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
for dialog in pbar:
if config['dstc'] != 10:
caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
else:
caption = [tokenize('no', tokenizer)]
questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist=[]
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
all_features = {}
fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']
dataname = '<FeaType>/<ImageID>.npy'
for ftype in fea_types:
if undisclosed_only:
basename = dataname.replace('<FeaType>', ftype+'_testset')
else:
basename = dataname.replace('<FeaType>', ftype)
features = {}
for vid in vid_set:
filename = basename.replace('<ImageID>', vid)
filepath = config['avsd_feature_path'] + filename
features[vid] = filepath
all_features[ftype] = features
return dialog_list, all_features
def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
""" Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """
bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
sep = eos
instance = {}
instance["lm_labels"] = reply + [eos]
caption = list(chain(*caption))
# Add state tokens if applicable
if add_state_tokens:
caption.insert(0, hist_state)
history = deepcopy(history_orig)
history[-1].insert(0, ques_state)
else:
history = history_orig
if not drop_caption:
# sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]
# NOTE It is important not to include the reply in the encoder input --> the decoder would just
# learn to copy it --> low train/val loss but no real learning happens
sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
else:
sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]
instance["input_ids"] = list(chain(*sequence))
return instance
def parse_args():
parser = ArgumentParser(description='debug dataloader')
parser.add_argument(
'--split',
type=str,
default='train',
help='train or val')
parser.add_argument(
'--model',
type=str,
default='mixer',
help='model name to train or test')
parser.add_argument(
'--log_dataset',
action='store_true',
default=False,
help='Whether or not to log the processed data')
parser.add_argument(
'--add_state_tokens',
action='store_true',
default=True,
help='Whether or not to add state tokens')
parser.add_argument(
'--log_dir',
type=str,
default='processed/avsd',
help='Output directory')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
split = args.split
config = pyhocon.ConfigFactory.parse_file(
'config/mst_mixer.conf')[args.model]
config['expand_rnd'] = False
config['debugging'] = False
config['overfit'] = False
args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']) )
if args.log_dataset:
log_dir = os.path.join(args.log_dir, split)
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
dialogs, features = get_dataset(config, split, tokenizer)
pbar = tqdm(dialogs)
pbar.set_description('[{}] Logging processed data'.format(split))
counter = 0
for dialog in pbar:
vid = dialog['vid']
his = dialog['history']
cap = dialog['caption']
ans = dialog['answer']
if np.random.rand() < config['caption_drop_rate']:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
else:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)
input_ids = torch.Tensor(instance["input_ids"]).long()
lm_labels = torch.Tensor(instance["lm_labels"]).long()
vgg = np.load(features["vggish"][vid])
i3d_flow = np.load(features["i3d_flow"][vid])
i3d_rgb = np.load(features["i3d_rgb"][vid])
sam = np.load(features["sam"][vid])
item = {
'input_ids': input_ids,
'lm_labels': lm_labels,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vgg': vgg,
'vid': vid
}
counter += 1
pth = os.path.join(log_dir, str(counter) + '.pkl')
with open(pth, 'wb') as f:
pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
avsd_dataset = AVSDDataset(config, 'val')
avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)
max_len = 0
for i, data in enumerate(avsd_dataloader):
print('{}/{}'.format(i, len(avsd_dataloader)))
max_len = max(max_len, data['input_ids'].size(1))
print('[INFO] Longest padded input sequence: {}'.format(max_len))
print('[INFO] Done...')
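
The interval bookkeeping in `collate_fn` above is easy to misread; the following standalone sketch reproduces just that arithmetic for a hypothetical `min_length` (an assumed value, not taken from any config), so the resulting encoder layout can be checked without loading data:
```python
# Reproduces the interval arithmetic of AVSDDataset.collate_fn for a
# hypothetical min_length of 10 sampled features per modality.
min_length = 10

i3d_rgb_interval  = [0, min_length + 1]                       # <s1> + rgb placeholders
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]      # <s0> + flow placeholders
sam_interval      = [2 * min_length + 2, 3 * min_length + 3]  # <s2> + object placeholders
audio_interval    = [3 * min_length + 3, 4 * min_length + 4]  # <s3> + vggish placeholders

text_offset = 4 * min_length + 4  # history/question intervals are shifted right by this amount
print(i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval, text_offset)
```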

custom_datasets/nextqa.py (new file, +211 lines)

@@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # frame
S1_TOK = '<s1>' # mot
S2_TOK = '<s2>' # question
def load_file(file_name):
annos = None
if os.path.splitext(file_name)[-1] == '.csv':
return pd.read_csv(file_name)
with open(file_name, 'r') as fp:
if os.path.splitext(file_name)[1]== '.txt':
annos = fp.readlines()
annos = [line.rstrip() for line in annos]
if os.path.splitext(file_name)[1] == '.json':
annos = json.load(fp)
return annos
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class NextQADataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
self.sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
self.frame_feats = {}
self.mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048)
self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
if self.config['overfit_size'] > 0:
self.sample_list = self.sample_list[:self.config['overfit_size']]
def __len__(self):
return len(self.sample_list)
def get_video_feature(self, video_name):
"""
:param video_name:
:return:
"""
app_feat = self.frame_feats[video_name]
app_feat = torch.from_numpy(app_feat).type(torch.float32)
mot_feat = self.mot_feats[video_name]
mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
return app_feat, mot_feat
def __getitem__(self, idx):
cur_sample = self.sample_list.loc[idx]
video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
input_ids = tokenize(ques, self.tokenizer)
lm_labels = tokenize(ans, self.tokenizer)
app_feat, mot_feat = self.get_video_feature(video_name)
bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])
# Add state tokens
input_ids.insert(0, ques_state)
lm_labels.append(eos)
question_interval = [0, len(input_ids)]
input_ids = torch.Tensor(input_ids).long()
lm_labels = torch.Tensor(lm_labels).long()
return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
app_feat_list.append(i[2])
mot_feat_list.append(i[3])
question_interval_list.append(i[4])
vid_ids_list.append(i[5])
app_feats = torch.stack(app_feat_list, dim=0).float()
mot_feats = torch.stack(mot_feat_list, dim=0).float()
question_intervals = np.array(question_interval_list)
pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), 16*2 + 2)) == 1 # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
# Create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), 16)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * app_sep, dummy,
torch.ones((len(batch), 1)) * mot_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
app_interval = [0, 16 + 1] # the last token is not part of this modality
mot_interval = [16 + 1, 2 * 16 + 2]
vis_state_vector_idx = [app_interval[0], mot_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
question_intervals += 2 * 16 + 2
question_intervals = question_intervals.tolist()
question_state_vector_idx = [x[0] for x in question_intervals]
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'app_feats': app_feats,
'mot_feats': mot_feats,
'lm_labels': lm_labels,
'input_mask': input_mask,
'app_interval': app_interval,
'mot_interval': mot_interval,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
def get_dataset(config, split):
bart_max_input_len = config['bart_max_input_len']
bart_size = config['bart_size']
sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
app_feats = {}
mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
app_feats[str(vid)] = feat[:, :2048] # (16, 2048)
mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
return sample_list, app_feats, mot_feats
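
A minimal usage sketch for `NextQADataset` follows; the config values and directory names are illustrative assumptions (the real values come from `config/mst_mixer.conf`), and the import assumes the `custom_datasets/` package layout shown in this commit:
```python
# Illustrative only: assumed config values and placeholder paths.
from torch.utils.data import DataLoader
from custom_datasets.nextqa import NextQADataset

config = {
    'bart_max_input_len': 1024,        # assumed
    'bart_size': 'large',              # assumed
    'log_dir': 'logs/nextqa',          # a copy of the tokenizer is saved here
    'nextqa_root': 'data/nextqa',      # must contain {split}.csv
    'nextqa_vid_feat': 'data/nextqa',  # must contain app_mot_{split}.h5
    'overfit_size': 0,
}
dataset = NextQADataset(config, split='val')
loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
print(batch['input_ids'].shape, batch['app_feats'].shape)  # app/mot feats: (B, 16, 2048)
```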

custom_datasets/segment.py (new file, +179 lines)

@@ -0,0 +1,179 @@
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
from argparse import ArgumentParser
import pickle
import cv2
import os
import torch
import numpy as np
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'--sam_ckpt',
type=str,
help='SAM checkpoint to be used'
)
parser.add_argument(
'--avsd_root',
type=str,
help='Directory where the individual AVSD frames are located'
)
parser.add_argument(
'--crop_root',
type=str,
help='Directory where the individual crops (objects) will be saved'
)
parser.add_argument(
'--embed_root',
type=str,
help='Directory where the individual embeddings will be saved'
)
parser.add_argument(
'--mode',
type=str,
choices=['segment', 'embed'],
help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
)
parser.add_argument(
'--start',
type=int,
default=0,
help='Start index of the partition'
)
parser.add_argument(
'--end',
type=int,
default=1968,
help='End index of the partition'
)
args = parser.parse_args()
return args
def partition_ids(avsd_ids, start, end):
avsd_ids.sort()
assert start < end
assert start >= 0 and end <= len(avsd_ids)
avsd_ids_partition = avsd_ids[start:end]
return avsd_ids_partition
def get_middle_frames(avsd_ids_partition, avsd_root):
pbar = tqdm(avsd_ids_partition)
pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
path_list = []
for avsd_id in pbar:
frames = os.listdir(os.path.join(avsd_root, avsd_id))
if 'test' in avsd_root:
frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
else:
frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
middle_frame = frames[int(len(frames)/2)]
middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
path_list.append(middle_frame)
return path_list
def segment_images(sam, path_list, crop_root):
mask_generator = SamAutomaticMaskGenerator(sam)
pbar = tqdm(path_list)
pbar.set_description('Detecting Objects')
for pth in pbar:
vid_id = pth.split('/')[-2]
crop_dir = os.path.join(crop_root, vid_id)
if not os.path.isdir(crop_dir):
os.makedirs(crop_dir)
image = cv2.imread(pth)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
masks.sort(key=lambda e: e['stability_score'], reverse=True)
if len(masks) > 36:
masks = masks[:36]
for i, mask in enumerate(masks):
crop = image[
int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
:
]
crop_flipped = cv2.flip(crop, 1) # Horizontal flip
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)
print('[INFO] Done...')
def embed_objects(sam, crop_ids, crop_root, embed_root):
predictor = SamPredictor(sam)
pbar = tqdm(crop_ids)
pbar.set_description('Embedding Objects')
for vid_id in pbar:
embeds = []
crop_dir = os.path.join(crop_root, vid_id)
crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
for cp in crop_paths:
crop = cv2.imread(cp)
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
predictor.set_image(crop)
embed_crop = predictor.get_image_embedding()
embed_crop = embed_crop.mean(-1).mean(-1)
crop_flipped = cv2.flip(crop, 1)
predictor.set_image(crop_flipped)
embed_crop_flipped = predictor.get_image_embedding()
embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)
embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
# embed = embed.copy().cpu()
embeds.append(embed)
embeds = torch.cat(embeds, 0).cpu().numpy()
np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)
print('[INFO] Done...')
def segment(args, sam):
avsd_ids = os.listdir(args.avsd_root)
avsd_ids.sort()
avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)
if not os.path.isdir(args.crop_root):
os.makedirs(args.crop_root)
segment_images(sam, path_list, args.crop_root)
def embed(args, sam):
crop_ids = os.listdir(args.crop_root)
crop_ids.sort()
crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
if not os.path.isdir(args.embed_root):
os.makedirs(args.embed_root)
embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)
if __name__ == '__main__':
args = parse_args()
sam = sam_model_registry['vit_h'](
checkpoint=args.sam_ckpt)
device = 'cuda'
sam.to(device=device)
assert args.mode in ['segment', 'embed']
if args.mode == 'segment':
segment(args, sam)
else:
embed(args, sam)
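
For reference, each file written by `embed_objects()` stacks one 512-dimensional vector per detected object (256 dims from the mean-pooled SAM embedding of the crop, 256 from its horizontal flip). A quick check might look like this sketch, where the path and video id are placeholders:
```python
# Load one per-video SAM embedding file produced by `--mode embed`.
import numpy as np

embeds = np.load('../features/sam/<video_id>.npy')  # placeholder path/video id
print(embeds.shape)  # (num_objects, 512), with num_objects <= 36
```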