Initial commit

Adnen Abdessaied 2022-03-30 10:46:35 +02:00
commit b5f3b728c3
53 changed files with 7008 additions and 0 deletions

0
core/data/.gitkeep Normal file

103
core/data/dataset.py Normal file

@@ -0,0 +1,103 @@
import glob, json
import numpy as np
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from core.data.utils import tokenize, ans_stat, proc_ques, qlen_to_key, ans_to_key
class VideoQA_Dataset(Dataset):
def __init__(self, __C):
super(VideoQA_Dataset, self).__init__()
self.__C = __C
self.ans_size = __C.NUM_ANS
# load raw data
with open(__C.QA_PATH[__C.RUN_MODE], 'r') as f:
self.raw_data = json.load(f)
self.data_size = len(self.raw_data)
splits = __C.SPLIT[__C.RUN_MODE].split('+')
frames_list = glob.glob(__C.FRAMES + '*.pt')
clips_list = glob.glob(__C.CLIPS + '*.pt')
        if 'msvd' in self.__C.DATASET_PATH.lower():
            # MSVD feature files are named 'vid<id>.pt' -> strip the 3-char prefix
            vid_ids = [int(s.split('/')[-1].split('.')[0][3:]) for s in frames_list]
        else:
            # other datasets (e.g. MSRVTT 'video<id>.pt') use a 5-char prefix
            vid_ids = [int(s.split('/')[-1].split('.')[0][5:]) for s in frames_list]
self.frames_dict = {k: v for (k,v) in zip(vid_ids, frames_list)}
self.clips_dict = {k: v for (k,v) in zip(vid_ids, clips_list)}
del frames_list, clips_list
        q_list = []
        a_list = []
        a_dict = defaultdict(int)
        # Count answer frequencies over the train and val splits
        for split in ['train', 'val']:
            with open(__C.QA_PATH[split], 'r') as f:
                qa_data = json.load(f)
            for d in qa_data:
                q_list.append(d['question'])
                a_list.append(d['answer'])
                a_dict[d['answer']] += 1
        # Answers ranked by frequency, most frequent first
        top_answers = sorted(a_dict, key=a_dict.get, reverse=True)
self.qlen_bins_to_idx = {
'1-3': 0,
'4-8': 1,
'9-15': 2,
}
self.ans_rare_to_idx = {
'0-99': 0,
'100-299': 1,
'300-999': 2,
}
self.qtypes_to_idx = {
'what': 0,
'who': 1,
'how': 2,
'when': 3,
'where': 4,
}
if __C.RUN_MODE == 'train':
self.ans_list = top_answers[:self.ans_size]
self.ans_to_ix, self.ix_to_ans = ans_stat(self.ans_list)
self.token_to_ix, self.pretrained_emb = tokenize(q_list, __C.USE_GLOVE)
            self.token_size = len(self.token_to_ix)
print('== Question token vocab size:', self.token_size)
self.idx_to_qtypes = {v: k for (k, v) in self.qtypes_to_idx.items()}
self.idx_to_qlen_bins = {v: k for (k, v) in self.qlen_bins_to_idx.items()}
self.idx_to_ans_rare = {v: k for (k, v) in self.ans_rare_to_idx.items()}
def __getitem__(self, idx):
sample = self.raw_data[idx]
ques = sample['question']
q_type = self.qtypes_to_idx[ques.split(' ')[0]]
ques_idx, qlen, _ = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN)
qlen_bin = self.qlen_bins_to_idx[qlen_to_key(qlen)]
        answer = sample['answer']
        # Out-of-vocabulary answers fall back to a random in-vocabulary index
        answer = self.ans_to_ix.get(answer, np.random.randint(0, high=len(self.ans_list)))
        ans_rarity = self.ans_rare_to_idx[ans_to_key(answer)]
answer_one_hot = torch.zeros(self.ans_size)
answer_one_hot[answer] = 1.0
vid_id = sample['video_id']
        frames = torch.load(self.frames_dict[vid_id], map_location='cpu')
        clips = torch.load(self.clips_dict[vid_id], map_location='cpu')
return torch.from_numpy(ques_idx).long(), frames, clips, answer_one_hot, torch.tensor(answer).long(), \
torch.tensor(q_type).long(), torch.tensor(qlen_bin).long(), torch.tensor(ans_rarity).long()
def __len__(self):
return self.data_size
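
For orientation, a minimal sketch of consuming this dataset with a PyTorch DataLoader; the config object __C and the batch settings below are hypothetical stand-ins:

from torch.utils.data import DataLoader

# __C is a hypothetical config carrying the attributes used above
dataset = VideoQA_Dataset(__C)  # with __C.RUN_MODE == 'train'
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
ques_ix, frames, clips, ans_1h, ans, q_type, qlen_bin, rarity = next(iter(loader))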

182
core/data/preprocess.py Normal file

@@ -0,0 +1,182 @@
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import skvideo.io as skv
import torch
import pickle
from PIL import Image
import tqdm
import numpy as np
from model.C3D import C3D
from torchvision.models import vgg19
import torchvision.transforms as transforms
import torch.nn as nn
import argparse
def _select_frames(path, frame_num):
    """Select `frame_num` representative frames from a video.

    Frames at the very beginning and end of the video are skipped.

    Args:
        path: Path of the video.
        frame_num: Number of frames to sample.
    Returns:
        frames: List of resized (224, 224, 3) frame arrays.
    """
frames = list()
video_data = skv.vread(path)
total_frames = video_data.shape[0]
    # Sample frame_num indices evenly through the video, skipping the edges.
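    # e.g. total_frames=100, frame_num=20: np.linspace(0, 100, 22)[1:21]
    # yields 20 evenly spaced indices that avoid the first and last segments.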
for i in np.linspace(0, total_frames, frame_num + 2)[1:frame_num + 1]:
frame_data = video_data[int(i)]
img = Image.fromarray(frame_data)
img = img.resize((224, 224), Image.BILINEAR)
frame_data = np.array(img)
frames.append(frame_data)
return frames
def _select_clips(path, clip_num):
    """Select `clip_num` clips from a video. Each clip has 16 frames.

    Args:
        path: Path of the video.
        clip_num: Number of clips to sample.
    Returns:
        clips: List of clips, each a list of 16 resized (112, 112, 3) frames.
    """
    clips = list()
    video_data = skv.vread(path)
    total_frames = video_data.shape[0]
for i in np.linspace(0, total_frames, clip_num + 2)[1:clip_num + 1]:
# Select center frame first, then include surrounding frames
clip_start = int(i) - 8
clip_end = int(i) + 8
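        # Keep the 16-frame window inside the video bounds by shifting it
        # back when it over-runs either end (assumes >= 16 frames total).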
if clip_start < 0:
clip_end = clip_end - clip_start
clip_start = 0
if clip_end > total_frames:
clip_start = clip_start - (clip_end - total_frames)
clip_end = total_frames
clip = video_data[clip_start:clip_end]
new_clip = []
for j in range(16):
frame_data = clip[j]
img = Image.fromarray(frame_data)
img = img.resize((112, 112), Image.BILINEAR)
            frame_data = np.array(img) * 1.0  # cast to float
            new_clip.append(frame_data)
clips.append(new_clip)
return clips
def preprocess_videos(video_dir, frame_num, clip_num):
    frames_dir = os.path.join(os.path.dirname(video_dir), 'frames')
    os.makedirs(frames_dir, exist_ok=True)
    clips_dir = os.path.join(os.path.dirname(video_dir), 'clips')
    os.makedirs(clips_dir, exist_ok=True)
for video_name in tqdm.tqdm(os.listdir(video_dir)):
video_path = os.path.join(video_dir, video_name)
frames = _select_frames(video_path, frame_num)
clips = _select_clips(video_path, clip_num)
with open(os.path.join(frames_dir, video_name.split('.')[0] + '.pkl'), "wb") as f:
pickle.dump(frames, f, protocol=pickle.HIGHEST_PROTOCOL)
with open(os.path.join(clips_dir, video_name.split('.')[0] + '.pkl'), "wb") as f:
pickle.dump(clips, f, protocol=pickle.HIGHEST_PROTOCOL)
def generate_video_features(path_frames, path_clips, c3d_path):
device = torch.device('cuda:0')
frame_feat_dir = os.path.join(os.path.dirname(path_frames), 'frame_feat')
os.makedirs(frame_feat_dir, exist_ok=True)
clip_feat_dir = os.path.join(os.path.dirname(path_frames), 'clip_feat')
os.makedirs(clip_feat_dir, exist_ok=True)
cnn = vgg19(pretrained=True)
    # Drop the last fc layer to expose VGG-19's 4096-d features
    cnn.classifier = nn.Sequential(
        *list(cnn.classifier.children())[:-1])
cnn.to(device).eval()
c3d = C3D()
c3d.load_state_dict(torch.load(c3d_path))
c3d.to(device).eval()
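    # Normalize with ImageNet mean/std, as the pretrained VGG-19 expects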
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
for vid_name in tqdm.tqdm(os.listdir(path_frames)):
frame_path = os.path.join(path_frames, vid_name)
clip_path = os.path.join(path_clips, vid_name)
        with open(frame_path, 'rb') as f:
            frames = pickle.load(f)
        with open(clip_path, 'rb') as f:
            clips = pickle.load(f)
frames = [transform(f) for f in frames]
frame_feat = []
clip_feat = []
for frame in frames:
with torch.no_grad():
feat = cnn(frame.unsqueeze(0).to(device))
frame_feat.append(feat)
        for clip in clips:
            # clip is a list of 16 (112, 112, 3) frames, i.e. (f, h, w, c);
            # reorder to (1, c, f, h, w) as C3D expects
            clip = torch.from_numpy(np.float32(np.array(clip)))
            clip = clip.permute(3, 0, 1, 2).unsqueeze(0).to(device)
with torch.no_grad():
feat = c3d(clip)
clip_feat.append(feat)
frame_feat = torch.cat(frame_feat, dim=0)
clip_feat = torch.cat(clip_feat, dim=0)
torch.save(frame_feat, os.path.join(frame_feat_dir, vid_name.split('.')[0] + '.pt'))
torch.save(clip_feat, os.path.join(clip_feat_dir, vid_name.split('.')[0] + '.pt'))
def parse_args():
'''
Parse input arguments
'''
parser = argparse.ArgumentParser(description='Preprocessing Args')
parser.add_argument('--RAW_VID_PATH', dest='RAW_VID_PATH',
help='The path to the raw videos',
required=True,
type=str)
parser.add_argument('--FRAMES_OUTPUT_DIR', dest='FRAMES_OUTPUT_DIR',
help='The directory where the processed frames and their features will be stored',
required=True,
type=str)
    parser.add_argument('--CLIPS_OUTPUT_DIR', dest='CLIPS_OUTPUT_DIR',
                        help='The directory where the processed clips and their features will be stored',
                        required=True,
                        type=str)
parser.add_argument('--C3D_PATH', dest='C3D_PATH',
help='Pretrained C3D path',
required=True,
type=str)
parser.add_argument('--NUM_SAMPLES', dest='NUM_SAMPLES',
help='The number of frames/clips to be sampled from the video',
default=20,
type=int)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
preprocess_videos(args.RAW_VID_PATH, args.NUM_SAMPLES, args.NUM_SAMPLES)
    # Features are written next to the raw videos (see preprocess_videos);
    # the *_OUTPUT_DIR flags are parsed but not used by this script.
    frames_dir = os.path.join(os.path.dirname(args.RAW_VID_PATH), 'frames')
    clips_dir = os.path.join(os.path.dirname(args.RAW_VID_PATH), 'clips')
    generate_video_features(frames_dir, clips_dir, args.C3D_PATH)
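
For reference, a minimal sketch of driving the same pipeline programmatically rather than through the CLI; all paths here are hypothetical and a pretrained C3D checkpoint is assumed:

# Hypothetical paths; preprocess_videos writes 'frames'/'clips' next to the videos
preprocess_videos('/data/msvd/videos', frame_num=20, clip_num=20)
generate_video_features('/data/msvd/frames', '/data/msvd/clips', '/models/c3d.pickle')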

81
core/data/utils.py Normal file

@@ -0,0 +1,81 @@
import en_vectors_web_lg, random, re
import numpy as np
def tokenize(ques_list, use_glove):
token_to_ix = {
'PAD': 0,
'UNK': 1,
}
spacy_tool = None
pretrained_emb = []
if use_glove:
spacy_tool = en_vectors_web_lg.load()
pretrained_emb.append(spacy_tool('PAD').vector)
pretrained_emb.append(spacy_tool('UNK').vector)
for ques in ques_list:
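        # Strip punctuation and break hyphen/slash compounds into separate tokens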
words = re.sub(
r"([.,'!?\"()*#:;])",
'',
ques.lower()
).replace('-', ' ').replace('/', ' ').split()
for word in words:
if word not in token_to_ix:
token_to_ix[word] = len(token_to_ix)
if use_glove:
pretrained_emb.append(spacy_tool(word).vector)
pretrained_emb = np.array(pretrained_emb)
return token_to_ix, pretrained_emb
def proc_ques(ques, token_to_ix, max_token):
ques_ix = np.zeros(max_token, np.int64)
words = re.sub(
r"([.,'!?\"()*#:;])",
'',
ques.lower()
).replace('-', ' ').replace('/', ' ').split()
    q_len = 0
    for ix, word in enumerate(words):
        if word in token_to_ix:
            ques_ix[ix] = token_to_ix[word]
            # only in-vocabulary words count towards the question length
            q_len += 1
        else:
            ques_ix[ix] = token_to_ix['UNK']
        if ix + 1 == max_token:
            break
    return ques_ix, q_len, len(words)
def ans_stat(ans_list):
ans_to_ix, ix_to_ans = {}, {}
for i, ans in enumerate(ans_list):
ans_to_ix[ans] = i
ix_to_ans[i] = ans
return ans_to_ix, ix_to_ans
def shuffle_list(ans_list):
random.shuffle(ans_list)
def qlen_to_key(q_len):
    if 1 <= q_len <= 3:
        return '1-3'
    if 4 <= q_len <= 8:
        return '4-8'
    if 9 <= q_len:
        return '9-15'
def ans_to_key(ans_idx):
    if 0 <= ans_idx <= 99:
        return '0-99'
    if 100 <= ans_idx <= 299:
        return '100-299'
    if 300 <= ans_idx <= 999:
        return '300-999'
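
A minimal sketch of how these helpers fit together, using a tiny hand-written question list (GloVe disabled, so en_vectors_web_lg is not required):

questions = ['what is the man doing', 'who is singing on stage']
token_to_ix, _ = tokenize(questions, use_glove=False)
ques_ix, q_len, n_words = proc_ques(questions[0], token_to_ix, max_token=15)
print(q_len, qlen_to_key(q_len))  # 5 '4-8'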