Initial commit
commit b5f3b728c3
53 changed files with 7008 additions and 0 deletions
0    core/data/.gitkeep    Normal file
103    core/data/dataset.py    Normal file
@@ -0,0 +1,103 @@
import glob, os, json, pickle
import numpy as np
from collections import defaultdict

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from core.data.utils import tokenize, ans_stat, proc_ques, qlen_to_key, ans_to_key


class VideoQA_Dataset(Dataset):
    def __init__(self, __C):
        super(VideoQA_Dataset, self).__init__()
        self.__C = __C
        self.ans_size = __C.NUM_ANS
        # load raw data
        with open(__C.QA_PATH[__C.RUN_MODE], 'r') as f:
            self.raw_data = json.load(f)
        self.data_size = len(self.raw_data)

        splits = __C.SPLIT[__C.RUN_MODE].split('+')

        frames_list = glob.glob(__C.FRAMES + '*.pt')
        clips_list = glob.glob(__C.CLIPS + '*.pt')
        if 'msvd' in self.__C.DATASET_PATH.lower():
            vid_ids = [int(s.split('/')[-1].split('.')[0][3:]) for s in frames_list]
        else:
            vid_ids = [int(s.split('/')[-1].split('.')[0][5:]) for s in frames_list]
        self.frames_dict = {k: v for (k, v) in zip(vid_ids, frames_list)}
        self.clips_dict = {k: v for (k, v) in zip(vid_ids, clips_list)}
        del frames_list, clips_list

        q_list = []
        a_list = []
        a_dict = defaultdict(lambda: 0)
        for split in ['train', 'val']:
            with open(__C.QA_PATH[split], 'r') as f:
                qa_data = json.load(f)
            for d in qa_data:
                q_list.append(d['question'])
                a_list.append(d['answer'])
                if d['answer'] not in a_dict:
                    a_dict[d['answer']] = 1
                else:
                    a_dict[d['answer']] += 1

        top_answers = sorted(a_dict, key=a_dict.get, reverse=True)
        self.qlen_bins_to_idx = {
            '1-3': 0,
            '4-8': 1,
            '9-15': 2,
        }
        self.ans_rare_to_idx = {
            '0-99': 0,
            '100-299': 1,
            '300-999': 2,
        }
        self.qtypes_to_idx = {
            'what': 0,
            'who': 1,
            'how': 2,
            'when': 3,
            'where': 4,
        }

        if __C.RUN_MODE == 'train':
            self.ans_list = top_answers[:self.ans_size]

            self.ans_to_ix, self.ix_to_ans = ans_stat(self.ans_list)

            self.token_to_ix, self.pretrained_emb = tokenize(q_list, __C.USE_GLOVE)
            self.token_size = len(self.token_to_ix)
            print('== Question token vocab size:', self.token_size)

        self.idx_to_qtypes = {v: k for (k, v) in self.qtypes_to_idx.items()}
        self.idx_to_qlen_bins = {v: k for (k, v) in self.qlen_bins_to_idx.items()}
        self.idx_to_ans_rare = {v: k for (k, v) in self.ans_rare_to_idx.items()}

    def __getitem__(self, idx):
        sample = self.raw_data[idx]
        ques = sample['question']
        q_type = self.qtypes_to_idx[ques.split(' ')[0]]
        ques_idx, qlen, _ = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN)
        qlen_bin = self.qlen_bins_to_idx[qlen_to_key(qlen)]

        answer = sample['answer']
        answer = self.ans_to_ix.get(answer, np.random.randint(0, high=len(self.ans_list)))
        ans_rarity = self.ans_rare_to_idx[ans_to_key(answer)]

        answer_one_hot = torch.zeros(self.ans_size)
        answer_one_hot[answer] = 1.0

        vid_id = sample['video_id']
        frames = torch.load(open(self.frames_dict[vid_id], 'rb')).cpu()
        clips = torch.load(open(self.clips_dict[vid_id], 'rb')).cpu()

        return torch.from_numpy(ques_idx).long(), frames, clips, answer_one_hot, torch.tensor(answer).long(), \
            torch.tensor(q_type).long(), torch.tensor(qlen_bin).long(), torch.tensor(ans_rarity).long()

    def __len__(self):
        return self.data_size
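A minimal usage sketch for the dataset above, assuming a config object that exposes the attributes read in __init__ (RUN_MODE, QA_PATH, SPLIT, FRAMES, CLIPS, DATASET_PATH, NUM_ANS, USE_GLOVE, MAX_TOKEN). The Cfg class and all paths below are illustrative placeholders, not part of this commit:

# Hypothetical illustration only; the project's real config object is defined elsewhere.
from torch.utils.data import DataLoader
from core.data.dataset import VideoQA_Dataset

class Cfg:  # assumed stand-in for the project's config
    RUN_MODE = 'train'
    NUM_ANS = 1000
    USE_GLOVE = True
    MAX_TOKEN = 15
    DATASET_PATH = '/data/msvd/'                     # placeholder
    FRAMES = '/data/msvd/frame_feat/'                # placeholder, must end with '/'
    CLIPS = '/data/msvd/clip_feat/'                  # placeholder, must end with '/'
    QA_PATH = {'train': '/data/msvd/train_qa.json',  # placeholder QA annotation files
               'val': '/data/msvd/val_qa.json'}
    SPLIT = {'train': 'train'}

dataset = VideoQA_Dataset(Cfg())
loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Each batch mirrors the 8-tuple returned by __getitem__.
ques_ix, frames, clips, ans_one_hot, ans, q_type, qlen_bin, ans_rarity = next(iter(loader))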
182    core/data/preprocess.py    Normal file
@@ -0,0 +1,182 @@
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import skvideo.io as skv
import torch
import pickle
from PIL import Image
import tqdm
import numpy as np
from model.C3D import C3D
import json
from torchvision.models import vgg19
import torchvision.transforms as transforms
import torch.nn as nn
import argparse


def _select_frames(path, frame_num):
    """Select representative frames for video.
    Ignore some frames both at begin and end of video.
    Args:
        path: Path of video.
    Returns:
        frames: list of frames.
    """
    frames = list()
    video_data = skv.vread(path)
    total_frames = video_data.shape[0]
    # Ignore some frame at begin and end.
    for i in np.linspace(0, total_frames, frame_num + 2)[1:frame_num + 1]:
        frame_data = video_data[int(i)]
        img = Image.fromarray(frame_data)
        img = img.resize((224, 224), Image.BILINEAR)
        frame_data = np.array(img)
        frames.append(frame_data)
    return frames


def _select_clips(path, clip_num):
    """Select clip_num clips for video. Each clip has 16 frames.
    Args:
        path: Path of video.
    Returns:
        clips: list of clips.
    """
    clips = list()
    # video_info = skvideo.io.ffprobe(path)
    video_data = skv.vread(path)
    total_frames = video_data.shape[0]
    height = video_data.shape[1]
    width = video_data.shape[2]
    for i in np.linspace(0, total_frames, clip_num + 2)[1:clip_num + 1]:
        # Select center frame first, then include surrounding frames
        clip_start = int(i) - 8
        clip_end = int(i) + 8
        if clip_start < 0:
            clip_end = clip_end - clip_start
            clip_start = 0
        if clip_end > total_frames:
            clip_start = clip_start - (clip_end - total_frames)
            clip_end = total_frames
        clip = video_data[clip_start:clip_end]
        new_clip = []
        for j in range(16):
            frame_data = clip[j]
            img = Image.fromarray(frame_data)
            img = img.resize((112, 112), Image.BILINEAR)
            frame_data = np.array(img) * 1.0
            # frame_data -= self.mean[j]
            new_clip.append(frame_data)
        clips.append(new_clip)
    return clips


def preprocess_videos(video_dir, frame_num, clip_num):
    frames_dir = os.path.join(os.path.dirname(video_dir), 'frames')
    os.mkdir(frames_dir)

    clips_dir = os.path.join(os.path.dirname(video_dir), 'clips')
    os.mkdir(clips_dir)

    for video_name in tqdm.tqdm(os.listdir(video_dir)):
        video_path = os.path.join(video_dir, video_name)
        frames = _select_frames(video_path, frame_num)
        clips = _select_clips(video_path, clip_num)

        with open(os.path.join(frames_dir, video_name.split('.')[0] + '.pkl'), "wb") as f:
            pickle.dump(frames, f, protocol=pickle.HIGHEST_PROTOCOL)

        with open(os.path.join(clips_dir, video_name.split('.')[0] + '.pkl'), "wb") as f:
            pickle.dump(clips, f, protocol=pickle.HIGHEST_PROTOCOL)


def generate_video_features(path_frames, path_clips, c3d_path):
    device = torch.device('cuda:0')
    frame_feat_dir = os.path.join(os.path.dirname(path_frames), 'frame_feat')
    os.makedirs(frame_feat_dir, exist_ok=True)

    clip_feat_dir = os.path.join(os.path.dirname(path_frames), 'clip_feat')
    os.makedirs(clip_feat_dir, exist_ok=True)

    cnn = vgg19(pretrained=True)
    in_features = cnn.classifier[-1].in_features
    cnn.classifier = nn.Sequential(
        *list(cnn.classifier.children())[:-1])  # remove last fc layer
    cnn.to(device).eval()
    c3d = C3D()
    c3d.load_state_dict(torch.load(c3d_path))
    c3d.to(device).eval()
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.485, 0.456, 0.406),
                                                         (0.229, 0.224, 0.225))])
    for vid_name in tqdm.tqdm(os.listdir(path_frames)):
        frame_path = os.path.join(path_frames, vid_name)
        clip_path = os.path.join(path_clips, vid_name)

        frames = pickle.load(open(frame_path, 'rb'))
        clips = pickle.load(open(clip_path, 'rb'))

        frames = [transform(f) for f in frames]
        frame_feat = []
        clip_feat = []

        for frame in frames:
            with torch.no_grad():
                feat = cnn(frame.unsqueeze(0).to(device))
            frame_feat.append(feat)
        for clip in clips:
            # clip has shape (c x f x h x w)
            clip = torch.from_numpy(np.float32(np.array(clip)))
            clip = clip.transpose(3, 0)
            clip = clip.transpose(3, 1)
            clip = clip.transpose(3, 2).unsqueeze(0).to(device)
            with torch.no_grad():
                feat = c3d(clip)
            clip_feat.append(feat)
        frame_feat = torch.cat(frame_feat, dim=0)
        clip_feat = torch.cat(clip_feat, dim=0)

        torch.save(frame_feat, os.path.join(frame_feat_dir, vid_name.split('.')[0] + '.pt'))
        torch.save(clip_feat, os.path.join(clip_feat_dir, vid_name.split('.')[0] + '.pt'))


def parse_args():
    '''
    Parse input arguments
    '''
    parser = argparse.ArgumentParser(description='Preprocessing Args')

    parser.add_argument('--RAW_VID_PATH', dest='RAW_VID_PATH',
                        help='The path to the raw videos',
                        required=True,
                        type=str)

    parser.add_argument('--FRAMES_OUTPUT_DIR', dest='FRAMES_OUTPUT_DIR',
                        help='The directory where the processed frames and their features will be stored',
                        required=True,
                        type=str)

    parser.add_argument('--CLIPS_OUTPUT_DIR', dest='CLIPS_OUTPUT_DIR',
                        help='The directory where the processed clips and their features will be stored',
                        required=True,
                        type=str)

    parser.add_argument('--C3D_PATH', dest='C3D_PATH',
                        help='Pretrained C3D path',
                        required=True,
                        type=str)

    parser.add_argument('--NUM_SAMPLES', dest='NUM_SAMPLES',
                        help='The number of frames/clips to be sampled from the video',
                        default=20,
                        type=int)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    preprocess_videos(args.RAW_VID_PATH, args.NUM_SAMPLES, args.NUM_SAMPLES)
    frames_dir = os.path.join(os.path.dirname(args.RAW_VID_PATH), 'frames')
    clips_dir = os.path.join(os.path.dirname(args.RAW_VID_PATH), 'clips')
    generate_video_features(frames_dir, clips_dir, args.C3D_PATH)
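A sketch of the two-stage pipeline driven by the __main__ block above, calling the functions directly rather than through argparse; every path below is a placeholder, not a path shipped with this commit:

# Illustrative only: paths are assumed, and the pretrained C3D weights file must exist.
from core.data.preprocess import preprocess_videos, generate_video_features

raw_videos = '/data/msvd/videos'  # directory holding the raw video files
# Stage 1: sample 20 frames and 20 clips per video into ../frames/*.pkl and ../clips/*.pkl
preprocess_videos(raw_videos, frame_num=20, clip_num=20)
# Stage 2: encode frames with VGG-19 and clips with C3D into ../frame_feat/*.pt and ../clip_feat/*.pt
generate_video_features('/data/msvd/frames', '/data/msvd/clips', '/data/pretrained/c3d.pickle')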
81    core/data/utils.py    Normal file
@@ -0,0 +1,81 @@
import en_vectors_web_lg, random, re, json
import numpy as np

def tokenize(ques_list, use_glove):
    token_to_ix = {
        'PAD': 0,
        'UNK': 1,
    }

    spacy_tool = None
    pretrained_emb = []
    if use_glove:
        spacy_tool = en_vectors_web_lg.load()
        pretrained_emb.append(spacy_tool('PAD').vector)
        pretrained_emb.append(spacy_tool('UNK').vector)

    for ques in ques_list:
        words = re.sub(
            r"([.,'!?\"()*#:;])",
            '',
            ques.lower()
        ).replace('-', ' ').replace('/', ' ').split()

        for word in words:
            if word not in token_to_ix:
                token_to_ix[word] = len(token_to_ix)
                if use_glove:
                    pretrained_emb.append(spacy_tool(word).vector)

    pretrained_emb = np.array(pretrained_emb)

    return token_to_ix, pretrained_emb


def proc_ques(ques, token_to_ix, max_token):
    ques_ix = np.zeros(max_token, np.int64)

    words = re.sub(
        r"([.,'!?\"()*#:;])",
        '',
        ques.lower()
    ).replace('-', ' ').replace('/', ' ').split()
    q_len = 0
    for ix, word in enumerate(words):
        if word in token_to_ix:
            ques_ix[ix] = token_to_ix[word]
            q_len += 1
        else:
            ques_ix[ix] = token_to_ix['UNK']

        if ix + 1 == max_token:
            break

    return ques_ix, q_len, len(words)


def ans_stat(ans_list):
    ans_to_ix, ix_to_ans = {}, {}
    for i, ans in enumerate(ans_list):
        ans_to_ix[ans] = i
        ix_to_ans[i] = ans

    return ans_to_ix, ix_to_ans


def shuffle_list(ans_list):
    random.shuffle(ans_list)


def qlen_to_key(q_len):
    if 1 <= q_len <= 3:
        return '1-3'
    if 4 <= q_len <= 8:
        return '4-8'
    if 9 <= q_len:
        return '9-15'


def ans_to_key(ans_idx):
    if 0 <= ans_idx <= 99:
        return '0-99'
    if 100 <= ans_idx <= 299:
        return '100-299'
    if 300 <= ans_idx <= 999:
        return '300-999'
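A short sketch of how the question helpers above fit together; the toy questions and answer vocabulary are illustrative only and GloVe lookup is disabled so the snippet runs without the en_vectors_web_lg model:

from core.data.utils import tokenize, proc_ques, ans_stat, qlen_to_key

questions = ['what is the man playing', 'who is riding a horse']  # toy examples
token_to_ix, pretrained_emb = tokenize(questions, use_glove=False)

ques_ix, q_len, n_words = proc_ques('what is the man playing', token_to_ix, max_token=15)
print(ques_ix[:q_len], qlen_to_key(q_len))  # token indices and the '4-8' length bin

ans_to_ix, ix_to_ans = ans_stat(['play', 'ride'])  # toy answer vocabulary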