initial commit
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions
BIN data/CGRUM.mp4 Normal file (binary file not shown)
0 datasets/__init__.py Normal file
205 datasets/avsd_dataset.py Normal file
@@ -0,0 +1,205 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from itertools import chain
from .video_utils import read_frames_decord


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):
    if split != 'test':
        dialog_pth = config[f'anno_avsd_{split}']
    else:
        dialog_pth = config['anno_avsd_test_dstc_{}'.format(config['dstc'])]
    n_history = config['num_hist_turns_avsd']
    undisclosed_only = split == 'test'
    dialog_data = json.load(open(dialog_pth, 'r'))
    dialog_list = []
    vid_set = set()
    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading AVSD - {}'.format(split))
    for dialog in pbar:
        # if config['dstc'] != 10:
        caption = dialog['caption']
        summary = dialog['summary']
        # else:
        #     caption = 'no'
        #     summary = 'no'

        questions = [d['question'] for d in dialog['dialog']]
        answers = [d['answer'] for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'summary': summary}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'summary': summary}

            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    return dialog_list


def build_input_from_segments(caption, history, reply, tokenizer, drop_caption=False):
    """ Build a sequence of input from 3 segments: caption (caption + summary), history and last reply """

    bos, eos = tokenizer.convert_tokens_to_ids(['<s>', '</s>'])
    sep = eos

    instance = {}
    instance["lm_labels"] = reply + [eos]
    caption = list(chain(*caption))

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE It is important not to include the reply in the input of the encoder --> the decoder will just
        # learn to copy it --> low train/val loss but no learning is happening
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[sep] + s for s in history] + [[eos]]

    instance["input_ids"] = list(chain(*sequence))
    return instance


class AVSDDataSet(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split
                 # tokenizer, features=None, drop_rate=0.0, train=True
                 ):
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        self.dialogs = get_dataset(config, split)

        if split == 'test':
            self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_vid(self, vid_id):
        vid_dir_path = os.path.join(self.root_vis, vid_id + '.mp4')

        frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
        frames = [self.vis_processor(f).unsqueeze(0) for f in frames]

        vis = torch.cat(frames, dim=0)
        return vis

    def load_vid_old(self, vid_id):
        # if vid_id == 'QQM8M':
        #     print('bla')
        vid_dir_path = os.path.join(self.root_vis, vid_id)
        frame_paths = [os.path.join(vid_dir_path, f) for f in os.listdir(vid_dir_path)]
        frame_paths.sort()
        num_avail_frames = len(frame_paths)
        delta = int(num_avail_frames / (self.config['num_frames'] - 1))
        ran = list(range(0, num_avail_frames, delta))
        if len(ran) < self.config['num_frames']:
            ran.extend([num_avail_frames - 1 for _ in range(self.config['num_frames'] - len(ran))])
        if len(ran) > self.config['num_frames']:
            ran = ran[:self.config['num_frames']]
        assert len(ran) == self.config['num_frames'], f"vid {vid_id} - loaded {len(ran)}/{len(frame_paths)} frames"
        frame_paths = [frame_paths[i] for i in ran]
        vis = [Image.open(p).convert('RGB') for p in frame_paths]
        vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)
        vis = self.trans(vis)
        return vis

    def __getitem__(self, index):
        dialog = self.dialogs[index]
        vid_id = dialog['vid']

        caption = dialog['caption']
        summary = dialog['summary']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        summary = self.text_processor(summary)
        if self.config.dstc != 10:
            caption = caption + ' ' + summary

        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # load the video frames
        vis = self.load_vid(vid_id)

        return vis, caption, history, answer, vid_id


def load_avsd_dataset(config, vis_processor, text_processor, split):
    # data_file = config['anno_avsd_{}'.format(split)]
    # dataset_list = get_dataset(config, split, tokenizer_enc_dec)
    dataset = AVSDDataSet(config, 'avsd', vis_processor, text_processor, split)
    return dataset
279 datasets/champagne_dataset.py Normal file
@@ -0,0 +1,279 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):

    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']
    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns']
    vid_set = set()

    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption']
        questions = [all_questions[d['question']] for d in dialog['dialog']]
        answers = [all_answers[d['answer']] for d in dialog['dialog']]

        vid = dialog["image_id"]
        vid_set.add(vid)
        # if undisclosed_only:
        #     it = range(len(questions) - 1, len(questions))
        # else:
        it = range(len(questions))
        qalist = []
        history = []
        # if undisclosed_only:
        #     for n in range(len(questions)-1):
        #         qalist.append(questions[n])
        #         qalist.append(answers[n])
        #         history=qalist[max(-len(qalist),-n_history*2):]

        for n in it:
            # if undisclosed_only:
            #     assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]
    return dialog_list


class Champagne(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):

        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_{}'.format(medium)]

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # get the mapping between caption and image/video
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        ids = list(self.mapping.keys())
        ids.sort()

        # reserve some samples for validation
        if split == 'train':
            self.ids = ids[config.num_val_samples:]
        elif split == 'val':
            self.ids = ids[:config.num_val_samples]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)


    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        orig_len = [s.size(0) for s in seq]
        return result, orig_len

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]
        # load the videos
        pth = os.path.join(self.root_vis, item['path'])
        f_names = os.listdir(pth)
        if len(f_names) == 0:
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)

            # load the videos
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)
        f_names.sort()

        if len(f_names) < self.config['num_frames']:
            f_names += [f_names[-1]] * (self.config['num_frames'] - len(f_names))
        elif len(f_names) > self.config['num_frames']:
            f_names = f_names[:self.config['num_frames']]

        pth = [os.path.join(pth, f_name) for f_name in f_names]
        try:
            vis = [Image.open(p).convert('RGB') for p in pth]
        except:
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)

            # load the videos
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)
            f_names.sort()

            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]

        vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)

        dialog = item['dialog']

        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ''
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        # preprocess the textual data
        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok
        # if self.config.llm_family == 'flan_t5':
        #     answer = '<s> ' + self.text_processor(answer) + ' </s>'
        # else:
        #     answer = self.text_processor(answer) + eos_tok

        return vis, caption, history, answer


    # def collate_fn(self, batch):

    #     BOS, EOS, SEP = self.tokenizer_enc_dec.convert_tokens_to_ids(['<s>', '</s>', '</s>'])

    #     vis_list, cap_list, hist_list, ques_list, ans_list, index_list, vid_id_list = [], [], [], [], [], [], []
    #     batch_size = len(batch)
    #     for b in batch:
    #         vis_list.append(b[0])
    #         cap = [BOS] + tokenize(b[1], self.tokenizer_enc_dec) + [EOS]
    #         cap_list.append(torch.tensor(cap))
    #         if len(b[2]) != 0:
    #             hist = [[SEP] + tokenize(s, self.tokenizer_enc_dec) for s in b[2]] + [[EOS]]
    #             hist_list.append(torch.tensor(list(chain(*hist))))
    #         else:
    #             hist = [SEP] + tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #             hist_list.append(torch.tensor(hist))

    #         ques = tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #         ques_list.append(torch.tensor(ques))
    #         ans = tokenize(b[4], self.tokenizer_enc_dec) + [EOS]
    #         ans_list.append(torch.tensor(ans))
    #         index_list.append(b[5])
    #         vid_id_list.append(b[6])

    #     # pad and keep track of the original lengths
    #     cap_input_ids, cap_orig_lens = self.padding(cap_list, self.tokenizer_experts.pad_token_id)
    #     hist_input_ids, hist_orig_lens = self.padding(hist_list, self.tokenizer_experts.pad_token_id)
    #     ques_input_ids, ques_orig_lens = self.padding(ques_list, self.tokenizer_experts.pad_token_id)
    #     ans_input_ids, _ = self.padding(ans_list, -100)

    #     cap_attention_mask = cap_input_ids != self.tokenizer_experts.pad_token_id
    #     hist_attention_mask = hist_input_ids != self.tokenizer_experts.pad_token_id
    #     ques_attention_mask = ques_input_ids != self.tokenizer_experts.pad_token_id

    #     total_orig_lens = [sum(l) for l in zip(cap_orig_lens, hist_orig_lens, ques_orig_lens)]
    #     max_len = max(total_orig_lens)

    #     dummy_input_ids_enc_dec = torch.full((batch_size, max_len), self.tokenizer_experts.pad_token_id)
    #     enc_dec_attention_mask = torch.zeros_like(dummy_input_ids_enc_dec, dtype=torch.bool)
    #     for i, l in enumerate(total_orig_lens):
    #         enc_dec_attention_mask[i][:l] = True
    #     # add the masking of the visual input
    #     num_query_tok = self.config['num_temporal_query_tokens_{}'.format(self.config['bert_size'])]
    #     if self.medium in ['avsd', 'msrvtt', 'webvid', 'champagne']:
    #         vis_attention_mask = torch.ones((batch_size, 2 * num_query_tok), dtype=torch.bool)  # *2 for spatial and temporal queries
    #     else:
    #         vis_attention_mask = torch.ones((batch_size, num_query_tok), dtype=torch.bool)  # only spatial queries

    #     enc_dec_attention_mask = torch.concat((vis_attention_mask, enc_dec_attention_mask), dim=1)
    #     # Now prepare the data
    #     vis = torch.stack(vis_list, dim=0)
    #     cap = {
    #         'input_ids': cap_input_ids,
    #         'attention_mask': cap_attention_mask,
    #         'orig_lens': cap_orig_lens
    #     }

    #     hist = {
    #         'input_ids': hist_input_ids,
    #         'attention_mask': hist_attention_mask,
    #         'orig_lens': hist_orig_lens
    #     }

    #     ques = {
    #         'input_ids': ques_input_ids,
    #         'attention_mask': ques_attention_mask,
    #         'orig_lens': ques_orig_lens
    #     }

    #     ans = {
    #         'input_ids': ans_input_ids,
    #     }

    #     enc_dec_input = {
    #         'input_ids': dummy_input_ids_enc_dec,
    #         'attention_mask': enc_dec_attention_mask,
    #     }

    #     index = torch.tensor(index_list)
    #     return vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list


def load_champagne_dataset(config, vis_processor, text_processor, split):
    dataset = Champagne(config, 'champagne', vis_processor, text_processor, split)
    return dataset
137 datasets/dataloader.py Normal file
@@ -0,0 +1,137 @@
"""
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
"""

import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.distributed as dist
from utils.dist import *
import random
import logging

logger = logging.getLogger(__name__)


class MetaLoader(object):
    """ wraps multiple data loaders """
    def __init__(self, name2loader):
        """Iterates over multiple dataloaders and ensures all processes
        work on data from the same dataloader. This loader will end when
        the shortest dataloader raises a StopIteration exception.

        loaders: Dict, {name: dataloader}
        """
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
        index2name = {v: k for k, v in name2index.items()}

        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))

        random.shuffle(iter_order)
        iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)

        # sync
        if is_dist_avail_and_initialized():
            # make sure all processes have the same order so that
            # each step they will have data from the same loader
            dist.broadcast(iter_order, src=0)
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))

    def __str__(self):
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
            )
        return "\n".join(output)

    def __len__(self):
        return len(self.iter_order)

    def __iter__(self):
        """ this iterator will run indefinitely """
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch


def load_dataloaders(config, datasets, split, output_dict=False):
    if isinstance(datasets, dict):
        datasets = list(datasets.values())
    shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets)
    if config['distributed'] and split != 'test':
        num_tasks = get_world_size()
        global_rank = get_rank()
        samplers = create_samplers(
            datasets, shuffles, num_tasks, global_rank
        )
    else:
        samplers = [None] * len(datasets)

    batch_size = [dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size for dataset in datasets]
    collate_fns = []
    for dataset in datasets:
        if isinstance(dataset, ConcatDataset):
            collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None))
        else:
            collate_fns.append(getattr(dataset, 'collate_fn', None))

    loaders = create_loader(
        datasets,
        samplers,
        batch_size=batch_size,
        num_workers=[config.num_workers] * len(datasets),
        is_trains=shuffles,
        collate_fns=collate_fns,
    )  # [0]
    loaders_dict = {}
    if output_dict:
        for l in loaders:
            if isinstance(l.dataset, ConcatDataset):
                loaders_dict[l.dataset.datasets[0].medium] = l
            else:
                loaders_dict[l.dataset.medium] = l
        return loaders_dict
    return loaders


def create_samplers(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = True
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=True if n_worker > 0 else False,
        )
        loaders.append(loader)
    return loaders
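The MetaLoader docstring above describes interleaving several dataloaders with an iteration order that is shuffled once and then broadcast so every process draws from the same loader at each step. A minimal single-process usage sketch, assuming two toy datasets and an available CUDA device (the class moves the shuffled order to "cuda" before broadcasting); the dataset names here are placeholders, not part of this commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # two toy datasets standing in for e.g. the cc3m and webvid datasets
    ds_a = TensorDataset(torch.arange(8).float())
    ds_b = TensorDataset(torch.arange(12).float())

    name2loader = {
        'a': DataLoader(ds_a, batch_size=2),   # 4 batches
        'b': DataLoader(ds_b, batch_size=4),   # 3 batches
    }

    meta_loader = MetaLoader(name2loader)      # len(meta_loader) == 7
    for name, batch in meta_loader:
        # each step yields the source loader's name and one of its batches
        print(name, batch[0].shape)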
86 datasets/nextqa_dataset.py Normal file
@@ -0,0 +1,86 @@
import os
import pandas as pd
# import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from .video_utils import read_frames_decord


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)
    return annos


class NextQADataset(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):

        super().__init__()
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split

        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        with open(config['vid_mapping_nextqa'], 'r') as f:
            self.video_mapping = json.load(f)

        self.sample_list = load_file(self.config['anno_nextqa_{}'.format(split)])

        if split == 'test':
            self.sample_list = self.sample_list[config['start_idx_gen']: config['end_idx_gen']]
            self.captions = load_file(self.config['next_qa_captions_{}'.format(split)])
        else:
            self.captions = None

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.sample_list = self.sample_list[:num_samples]

    def __len__(self):
        return len(self.sample_list)


    def load_vid(self, vid_id):
        vid_dir_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4')

        frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
        frames = [self.vis_processor(f).unsqueeze(0) for f in frames]

        vis = torch.cat(frames, dim=0)
        return vis

    def __getitem__(self, idx):
        if self.split == 'test':
            idx += self.config['start_idx_gen']

        cur_sample = self.sample_list.loc[idx]
        video_id, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
            str(cur_sample['answer']), str(cur_sample['qid'])

        history = self.text_processor(ques)
        answer = self.text_processor(ans)
        if self.split == 'test':
            caption = self.text_processor(self.captions[video_id])
        else:
            caption = self.text_processor('please answer the following question based on the video')
        vis = self.load_vid(video_id)

        return vis, caption, history, answer, video_id, qid

def load_nextqa_dataset(config, vis_processor, text_processor, split):
    # data_file = config['anno_avsd_{}'.format(split)]
    # dataset_list = get_dataset(config, split, tokenizer_enc_dec)
    dataset = NextQADataset(config, 'nextqa', vis_processor, text_processor, split)
    return dataset
156 datasets/pretraining.py Normal file
@@ -0,0 +1,156 @@
from torch.utils.data import Dataset
import pickle
import os

import torch
from PIL import Image
import numpy as np
from torchvision import transforms
import random

from .utils import pre_text, type_transform_helper, load_anno, open_img

class CapDataset(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        super(CapDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # "webvid / cc3m / msrvtt"
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # get the mapping between caption and image/video
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        # These are the main ids of the dataset (typically one per image/vid)
        self.ids = list(self.mapping.keys())
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]
        # _id = self.ids[index]
        ############################# Textual features #############################
        caption = item['caption']
        # caption_ = pre_text(caption)
        caption = self.text_processor(caption)
        # add [CLS] token
        caption = '[CLS] ' + caption

        if self.medium == 'cc3m':
            pth = os.path.join(self.root_vis, item['file'])
            vis = open_img(pth)
            vis = self.vis_processor(vis).unsqueeze(0)
        else:
            pth = os.path.join(self.root_vis, item['file'])
            f_names = os.listdir(pth)
            f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]
            vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
            vis = torch.cat(vis, dim=0)

        # Get negative vis
        neg_index = random.randint(0, len(self) - 1)
        while neg_index == index:
            neg_index = random.randint(0, len(self) - 1)

        neg_item = self.mapping[self.ids[neg_index]]

        if self.medium == 'cc3m':
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_vis = open_img(neg_pth)
            neg_vis = self.vis_processor(neg_vis).unsqueeze(0)
        else:
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_f_names = os.listdir(neg_pth)
            neg_f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            neg_pth = [os.path.join(neg_pth, neg_f_name) for neg_f_name in neg_f_names]
            neg_vis = [Image.open(p).convert('RGB') for p in neg_pth]
            neg_vis = [self.vis_processor(v).unsqueeze(0) for v in neg_vis]
            neg_vis = torch.cat(neg_vis, dim=0)

        # return caption, vis
        return vis, caption, neg_vis


class VideoTextRetDataset(Dataset):
    def __init__(self, config, vis_processor, text_processor, medium, split):
        super(VideoTextRetDataset, self).__init__()

        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # "webvid / cc3m / msrvtt"
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test


        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        anno_path = config['annotation_{}_{}'.format(medium, split)]
        self.raw_anno_list = load_anno(anno_path)
        self.text = []
        self.vis = []
        self.txt2vis = {}
        self.vis2txt = {}
        self.build_data()
        self.anno_list = [dict(vis=v) for v in self.vis]
        # print('bla')

    def __len__(self):
        return len(self.anno_list)

    def __getitem__(self, index):
        pth = self.anno_list[index]['vis']
        f_names = os.listdir(pth)
        f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
        pth = [os.path.join(pth, f_name) for f_name in f_names]
        vis = [Image.open(p).convert('RGB') for p in pth]
        vis = [self.vis_processor(v) for v in vis]
        # vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)
        # vis = self.trans(vis)

        return vis, index

    def build_data(self):
        """each image may have multiple ground-truth texts, e.g., COCO and Flickr30K"""
        txt_id = 0
        for vis_id, ann in enumerate(self.raw_anno_list):
            self.vis.append(ann["vis"])
            self.vis2txt[vis_id] = []
            _captions = ann["caption"] \
                if isinstance(ann["caption"], list) else [ann["caption"], ]
            for i, caption in enumerate(_captions):
                # self.text.append(pre_text(caption))
                self.text.append(self.text_processor(caption))
                self.vis2txt[vis_id].append(txt_id)
                self.txt2vis[txt_id] = vis_id
                txt_id += 1


def load_datasets(config, vis_processor, text_processor, split):
    if config['stage'] == 'stage_1':
        if split != 'test':
            cc3m_dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, split)
            webvid_dataset = CapDataset(config, 'webvid', vis_processor, text_processor, split)
            datasets = {
                'cc3m': cc3m_dataset,
                'webvid': webvid_dataset
            }
        else:  # Test with msrvtt_1k --> video retrieval
            msrvtt_dataset = VideoTextRetDataset(config, vis_processor, text_processor, 'msrvtt', split)
            datasets = {
                'msrvtt': msrvtt_dataset
            }
    return datasets
83 datasets/utils.py Normal file
@@ -0,0 +1,83 @@
import os
import re
import json
from tqdm import trange
from utils.dist import is_main_process
from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
import numpy as np

def open_img(img_pth):
    try:
        img = Image.open(img_pth).convert('RGB')
        return img
    except:
        img = np.random.randint(0, high=256, size=(224, 224, 3))
        img = Image.fromarray(img, 'RGB')
        return img


def pre_text(text, max_l=None):
    text = re.sub(r"(['!?\"()*#:;~])", '', text.lower())
    text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    if max_l:  # truncate
        words = text.split(' ')
        if len(words) > max_l:
            text = ' '.join(words[:max_l])
    return text


def get_datasets_media(dataloaders):
    media = {}
    for dataloader in dataloaders:
        if isinstance(dataloader.dataset, ConcatDataset):
            media[dataloader.dataset.datasets[0].medium] = dataloader
        else:
            media[dataloader.dataset.medium] = dataloader

    # media = [dataloader.dataset.medium for dataloader in dataloaders]
    return media

def type_transform_helper(x):
    return x.float().div(255.0)

def load_anno(ann_file_list):
    """[summary]

    Args:
        ann_file_list (List[List[str, str]] or List[str, str]):
            the latter will be automatically converted to the former.
            Each sublist contains [anno_path, image_root] (or [anno_path, video_root, 'video']),
            which specifies the data type, video or image.

    Returns:
        List(dict): each dict is {
            image: str or List[str],  # image_path,
            caption: str or List[str]  # caption text string
        }
    """
    if isinstance(ann_file_list[0], str):
        ann_file_list = [ann_file_list]

    ann = []
    for d in ann_file_list:
        data_root = d[1]
        fp = d[0]
        is_video = len(d) == 3 and d[2] == "video"
        cur_ann = json.load(open(fp, "r"))
        iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
            if is_main_process() else range(len(cur_ann))
        for idx in iterator:
            key = "video" if is_video else "image"
            video_id = cur_ann[idx][key][5:].split('.')[0]
            # unified to have the same key for data path
            # if isinstance(cur_ann[idx][key], str):
            cur_ann[idx]["vis"] = os.path.join(data_root, video_id)
            # else:  # list
            #     cur_ann[idx]["vis"] = [os.path.join(data_root, e) for e in cur_ann[idx][key]]
        ann += cur_ann
    return ann
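The load_anno docstring above describes the annotation layout it expects; a small sketch of a call, with placeholder paths that are not files from this commit:

    # a single [anno_path, data_root] pair is wrapped into a list automatically;
    # appending 'video' marks the entries as videos rather than images
    anno = load_anno(['/path/to/msrvtt_test.json', '/path/to/msrvtt_frames', 'video'])

    # each returned dict keeps its original fields plus a unified 'vis' data path
    print(anno[0]['vis'], anno[0]['caption'])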
97 datasets/video_utils.py Normal file
@@ -0,0 +1,97 @@
"""
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
"""
import random
import decord
from PIL import Image
import numpy as np
import math
decord.bridge.set_bridge("torch")


def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Converts a present time with the given time base and start_pts offset to seconds.

    Returns:
        time_in_seconds (float): The corresponding time in seconds.

    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts == math.inf:
        return math.inf

    return int(pts - start_pts) * time_base


def get_pyav_video_duration(video_reader):
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)


def get_frame_indices_by_fps():
    pass


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except:
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError
    return frame_indices


def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
    video_reader = decord.VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    frames = frames.split(1, dim=0)

    frames = [Image.fromarray(f.squeeze().numpy(), mode='RGB') for f in frames]
    # frames = frames.numpy()  # convert to numpy
    return frames, frame_indices, duration
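A short sketch of how read_frames_decord is called by the dataset classes above, assuming a local MP4 file (the path is a placeholder, not part of this commit):

    # sample 8 frames spread over the clip; 'middle' takes the center of each
    # of the 8 intervals, while the default 'rand' picks a random frame per interval
    frames, frame_indices, duration = read_frames_decord(
        '/path/to/clip.mp4', num_frames=8, sample='middle'
    )
    print(len(frames), frame_indices, f'{duration:.1f}s')  # 8 PIL images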
183 datasets/visdial_dataset.py Normal file
@@ -0,0 +1,183 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from .utils import open_img

def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):

    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']

    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns_visdial']
    vid_set = set()
    undisclosed_only = False

    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption'] + ' .'
        questions = [all_questions[d['question']] + ' ?' for d in dialog['dialog']]
        answers = [all_answers[d['answer']] + ' .' for d in dialog['dialog']]

        # answer_opts = [[all_answers[key] for key in d['answer_options']] for d in dialog['dialog']]
        # if 'test' in config['anno_visdial_{}'.format(split)]:
        #     gt_indices = [-1 for _ in range(len(questions))]
        # else:
        #     gt_indices = [d['gt_index'] for d in dialog['dialog']]

        vid = dialog['image_id']
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))

        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]

        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            # answer_opt = answer_opts[n]
            # gt_index = gt_indices[n]
            history.append(question)
            # if n_history == 0:
            #     item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index}
            # else:
            #     item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index}

            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n + 1}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n + 1}

            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    return dialog_list


class VisDial(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split
                 # tokenizer, features=None, drop_rate=0.0, train=True
                 ):
        self.config = config
        self.medium = medium
        self.split = split
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        self.dialogs = get_dataset(config, split)

        if split == 'test':
            self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_img(self, vid_id):
        file_pth = os.path.join(self.root_vis, f'{vid_id}.jpg')
        vis = open_img(file_pth)
        vis = self.vis_processor(vis).unsqueeze(0)
        return vis

    def __getitem__(self, index):
        dialog = self.dialogs[index]

        vid_id = dialog['vid']
        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']
        d_round = dialog['round']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)


        # if self.split == 'test':
        #     answer_opts = dialog['answer_opts']
        #     answer_opts = [self.text_processor(a) for a in answer_opts]

        #     gt_index = dialog['gt_index']
        #     dialog_round = dialog['round']

        #     dense_key = str(vid_id) + '_' + str(dialog_round)
        #     gt_relevance = self.dense_annos.get(dense_key, -1)
        #     # eval_data = (answer_opts, gt_index, gt_relevance)


        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # load the image
        vis = self.load_img(vid_id)

        # if self.split == 'test':
        #     return vis, caption, history, answer, vid_id, answer_opts, gt_relevance, gt_index

        # else:
        return vis, caption, history, answer, vid_id, d_round



def load_visdial_dataset(config, vis_processor, text_processor, split):
    dataset = VisDial(config, 'visdial', vis_processor, text_processor, split)
    return dataset
BIN emergency/item.pkl Normal file (binary file not shown)
81 eval_visdial.py Normal file
@@ -0,0 +1,81 @@
import os
import torch
import json
from utils.metrcis import SparseGTMetrics, NDCG
from Levenshtein import ratio


output_dir = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial'

file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f, file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))

results = {}
count = 0
for pth in file_paths:
    with open(pth, 'r') as f:
        partial_res = json.load(f)
        count += len(partial_res)
        results.update(partial_res)
        # dialogs.extend(data['dialogs'])
    os.remove(pth)


name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
    json.dump(results, f, indent=4)


# result_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial/zeroshot_visdial_after_champagne_googleflant5large_results_dstc8_beam_depth_8_lenPen_0.3.json'

# with open(result_path, 'r') as f:
#     results = json.load(f)

annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json'
dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json'

with open(annos_path, 'r') as f:
    data = json.load(f)['data']

all_answers = data['answers']
all_questions = data['questions']


dialogs = data['dialogs']

dialogs_dict = {}

for dialog in dialogs:
    image_id = dialog['image_id']
    for i, turn in enumerate(dialog['dialog']):
        answer_opts = [all_answers[a] for a in turn['answer_options']]
        dialogs_dict[str(image_id) + '_' + str(i + 1)] = {
            'answer_opts': answer_opts,
            'gt_index': turn['gt_index']
        }
        # print('bla')

with open(dense_annos_path, 'r') as f:
    dense_data = json.load(f)

dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data}

sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

for res_key, res in results.items():
    answer_opts = dialogs_dict[res_key]['answer_opts']
    gt_index = torch.tensor(dialogs_dict[res_key]['gt_index'])

    scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0)
    sparse_metrics.observe(scores, gt_index)
    if res_key in dense_data:
        gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0)
        ndcg.observe(scores.squeeze(0), gt_relevance)
    # print('bla')

print(sparse_metrics.retrieve())
print(ndcg.retrieve())
273
eval_visdial_sentence_embeddings.py
Normal file
273
eval_visdial_sentence_embeddings.py
Normal file
|
@ -0,0 +1,273 @@
|
||||||
|
from sentence_transformers.cross_encoder import CrossEncoder
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def scores_to_ranks(scores: torch.Tensor):
|
||||||
|
"""Convert model output scores into ranks."""
|
||||||
|
batch_size, num_rounds, num_options = scores.size()
|
||||||
|
scores = scores.view(-1, num_options)
|
||||||
|
|
||||||
|
# sort in descending order - largest score gets highest rank
|
||||||
|
sorted_ranks, ranked_idx = scores.sort(1, descending=True)
|
||||||
|
|
||||||
|
# i-th position in ranked_idx specifies which score shall take this
|
||||||
|
# position but we want i-th position to have rank of score at that
|
||||||
|
# position, do this conversion
|
||||||
|
ranks = ranked_idx.clone().fill_(0)
|
||||||
|
for i in range(ranked_idx.size(0)):
|
||||||
|
for j in range(num_options):
|
||||||
|
ranks[i][ranked_idx[i][j]] = j
|
||||||
|
# convert from 0-99 ranks to 1-100 ranks
|
||||||
|
ranks += 1
|
||||||
|
ranks = ranks.view(batch_size, num_rounds, num_options)
|
||||||
|
return ranks
|
||||||
|
|
||||||
|
|
||||||
|
class SparseGTMetrics(object):
|
||||||
|
"""
|
||||||
|
A class to accumulate all metrics with sparse ground truth annotations.
|
||||||
|
These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._rank_list = []
|
||||||
|
|
||||||
|
def observe(
|
||||||
|
self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
|
||||||
|
):
|
||||||
|
predicted_scores = predicted_scores.detach()
|
||||||
|
|
||||||
|
# shape: (batch_size, num_rounds, num_options)
|
||||||
|
predicted_ranks = scores_to_ranks(predicted_scores)
|
||||||
|
batch_size, num_rounds, num_options = predicted_ranks.size()
|
||||||
|
|
||||||
|
# collapse batch dimension
|
||||||
|
predicted_ranks = predicted_ranks.view(
|
||||||
|
batch_size * num_rounds, num_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# shape: (batch_size * num_rounds, )
|
||||||
|
target_ranks = target_ranks.view(batch_size * num_rounds).long()
|
||||||
|
|
||||||
|
# shape: (batch_size * num_rounds, )
|
||||||
|
predicted_gt_ranks = predicted_ranks[
|
||||||
|
torch.arange(batch_size * num_rounds), target_ranks
|
||||||
|
]
|
||||||
|
self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))
|
||||||
|
|
||||||
|
def retrieve(self, reset: bool = True):
|
||||||
|
num_examples = len(self._rank_list)
|
||||||
|
if num_examples > 0:
|
||||||
|
# convert to numpy array for easy calculation.
|
||||||
|
__rank_list = torch.tensor(self._rank_list).float()
|
||||||
|
metrics = {
|
||||||
|
"r@1": torch.mean((__rank_list <= 1).float()).item(),
|
||||||
|
"r@5": torch.mean((__rank_list <= 5).float()).item(),
|
||||||
|
"r@10": torch.mean((__rank_list <= 10).float()).item(),
|
||||||
|
"mean": torch.mean(__rank_list).item(),
|
||||||
|
"mrr": torch.mean(__rank_list.reciprocal()).item(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
if reset:
|
||||||
|
self.reset()
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._rank_list = []
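# Added annotation, illustrative values: when the ground-truth option also scores
# highest, its predicted rank is 1, so every sparse metric is perfect.
#   m = SparseGTMetrics()
#   m.observe(torch.tensor([[[0.2, 0.7, 0.1]]]), torch.tensor(1))
#   m.retrieve()  # -> {'r@1': 1.0, 'r@5': 1.0, 'r@10': 1.0, 'mean': 1.0, 'mrr': 1.0}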
|
||||||
|
|
||||||
|
|
||||||
|
class NDCG(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._ndcg_numerator = 0.0
|
||||||
|
self._ndcg_denominator = 0.0
|
||||||
|
|
||||||
|
def observe(
|
||||||
|
self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Observe model output scores and target ground truth relevance and
|
||||||
|
accumulate NDCG metric.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
predicted_scores: torch.Tensor
|
||||||
|
A tensor of shape (batch_size, num_options), because dense
|
||||||
|
annotations are available for 1 randomly picked round out of 10.
|
||||||
|
target_relevance: torch.Tensor
|
||||||
|
A tensor of shape same as predicted scores, indicating ground truth
|
||||||
|
relevance of each answer option for a particular round.
|
||||||
|
"""
|
||||||
|
predicted_scores = predicted_scores.detach()
|
||||||
|
|
||||||
|
# shape: (batch_size, 1, num_options)
|
||||||
|
predicted_scores = predicted_scores.unsqueeze(1)
|
||||||
|
predicted_ranks = scores_to_ranks(predicted_scores)
|
||||||
|
|
||||||
|
# shape: (batch_size, num_options)
|
||||||
|
predicted_ranks = predicted_ranks.squeeze(1)
|
||||||
|
batch_size, num_options = predicted_ranks.size()
|
||||||
|
|
||||||
|
k = torch.sum(target_relevance != 0, dim=-1)
|
||||||
|
|
||||||
|
# shape: (batch_size, num_options)
|
||||||
|
_, rankings = torch.sort(predicted_ranks, dim=-1)
|
||||||
|
# Sort relevance in descending order so the highest relevance gets the top rank.
|
||||||
|
_, best_rankings = torch.sort(
|
||||||
|
target_relevance, dim=-1, descending=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# shape: (batch_size, )
|
||||||
|
batch_ndcg = []
|
||||||
|
for batch_index in range(batch_size):
|
||||||
|
|
||||||
|
num_relevant = k[batch_index]
|
||||||
|
dcg = self._dcg(
|
||||||
|
rankings[batch_index][:num_relevant],
|
||||||
|
target_relevance[batch_index],
|
||||||
|
)
|
||||||
|
best_dcg = self._dcg(
|
||||||
|
best_rankings[batch_index][:num_relevant],
|
||||||
|
target_relevance[batch_index],
|
||||||
|
)
|
||||||
|
batch_ndcg.append(dcg / best_dcg)
|
||||||
|
|
||||||
|
self._ndcg_denominator += batch_size
|
||||||
|
self._ndcg_numerator += sum(batch_ndcg)
|
||||||
|
|
||||||
|
def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
|
||||||
|
sorted_relevance = relevance[rankings].cpu().float()
|
||||||
|
discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
|
||||||
|
return torch.sum(sorted_relevance / discounts, dim=-1)
|
||||||
|
|
||||||
|
def retrieve(self, reset: bool = True):
|
||||||
|
if self._ndcg_denominator > 0:
|
||||||
|
metrics = {
|
||||||
|
"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
if reset:
|
||||||
|
self.reset()
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._ndcg_numerator = 0.0
|
||||||
|
self._ndcg_denominator = 0.0
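# Added annotation, illustrative values: relevance [0, 1, 0, 2] has k = 2 non-zero
# entries; scores [0.4, 0.2, 0.1, 0.9] place options 3 and 0 in the top-2, giving
# DCG = 2/log2(2) + 0/log2(3) = 2.0, while the ideal top-2 (options 3 and 1) gives
# 2 + 1/log2(3) ≈ 2.63, hence NDCG ≈ 0.76.
#   n = NDCG()
#   n.observe(torch.tensor([[0.4, 0.2, 0.1, 0.9]]), torch.tensor([[0., 1., 0., 2.]]))
#   n.retrieve()  # -> {'ndcg': 0.76...}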
|
||||||
|
|
||||||
|
|
||||||
|
annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json'
|
||||||
|
with open(annos_path, 'r') as f:
|
||||||
|
data = json.load(f)['data']
|
||||||
|
|
||||||
|
dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json'
|
||||||
|
with open(dense_annos_path, 'r') as f:
|
||||||
|
dense_data = json.load(f)
|
||||||
|
|
||||||
|
dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data}
|
||||||
|
|
||||||
|
results_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial_before_supplementary/zeroshot_visdial_after_avsd_4_frames_3_rounds_ft_fp16_googleflant5large_results_dstc10_beam_depth_4_lenPen_0.3.json'
|
||||||
|
with open(results_path, 'r') as f:
|
||||||
|
results = json.load(f)
|
||||||
|
|
||||||
|
all_answers = data['answers']
|
||||||
|
all_questions = data['questions']
|
||||||
|
|
||||||
|
|
||||||
|
dialogs = data['dialogs']
|
||||||
|
|
||||||
|
dialogs_dict = {}
|
||||||
|
|
||||||
|
for dialog in dialogs:
|
||||||
|
image_id = dialog['image_id']
|
||||||
|
for i, turn in enumerate(dialog['dialog']):
|
||||||
|
answer_opts = [all_answers[a] for a in turn['answer_options']]
|
||||||
|
dialogs_dict[str(image_id) + '_' + str(i+1)] = {
|
||||||
|
'answer_opts': answer_opts,
|
||||||
|
'gt_index': turn['gt_index']
|
||||||
|
}
|
||||||
|
# print('bla')
|
||||||
|
|
||||||
|
sparse_metrics = SparseGTMetrics()
|
||||||
|
ndcg = NDCG()
|
||||||
|
|
||||||
|
# 1. Load a pretrained CrossEncoder model
|
||||||
|
model = CrossEncoder("cross-encoder/stsb-roberta-large")
|
||||||
|
|
||||||
|
for i, (res_key, res) in enumerate(results.items()):
|
||||||
|
print('[INFO] {} / {}'.format(i+1, len(results)))
|
||||||
|
answer_opts = dialogs_dict[res_key]['answer_opts']
|
||||||
|
gt_index = torch.tensor(dialogs_dict[res_key]['gt_index'])
|
||||||
|
gt_answer = answer_opts[gt_index]
|
||||||
|
sentence_combinations = [[res, opt] for opt in answer_opts]
|
||||||
|
scores = model.predict(sentence_combinations)
|
||||||
|
scores = torch.from_numpy(scores).unsqueeze(0).unsqueeze(0)
|
||||||
|
# scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0)
|
||||||
|
# scores = model.rank(res, answer_opts)
|
||||||
|
ranked_idx = scores_to_ranks(scores).squeeze().tolist()
|
||||||
|
new_order = np.argsort(ranked_idx)
|
||||||
|
# ranked_answers = [answer_opts[idx] for idx in new_order]
|
||||||
|
best_pick = answer_opts[new_order[0]]
|
||||||
|
sparse_metrics.observe(scores, gt_index)
|
||||||
|
if res_key in dense_data:
|
||||||
|
gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0)
|
||||||
|
ndcg.observe(scores.squeeze(0), gt_relevance)
|
||||||
|
|
||||||
|
# print('bla')
|
||||||
|
print(sparse_metrics.retrieve())
|
||||||
|
print(ndcg.retrieve())
|
||||||
|
|
||||||
|
# We want to compute the similarity between the query sentence...
|
||||||
|
# query = "A man is eating pasta."
|
||||||
|
|
||||||
|
# # ... and all sentences in the corpus
|
||||||
|
# corpus = [
|
||||||
|
# "A man is eating food.",
|
||||||
|
# "A man is eating a piece of bread.",
|
||||||
|
# "The girl is carrying a baby.",
|
||||||
|
# "A man is riding a horse.",
|
||||||
|
# "A woman is playing violin.",
|
||||||
|
# "Two men pushed carts through the woods.",
|
||||||
|
# "A man is riding a white horse on an enclosed ground.",
|
||||||
|
# "A monkey is playing drums.",
|
||||||
|
# "A cheetah is running behind its prey.",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# # 2. We rank all sentences in the corpus for the query
|
||||||
|
# ranks = model.rank(query, corpus)
|
||||||
|
|
||||||
|
# # Print the scores
|
||||||
|
# print("Query: ", query)
|
||||||
|
# for rank in ranks:
|
||||||
|
# print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")
|
||||||
|
# """
|
||||||
|
# Query: A man is eating pasta.
|
||||||
|
# 0.67 A man is eating food.
|
||||||
|
# 0.34 A man is eating a piece of bread.
|
||||||
|
# 0.08 A man is riding a horse.
|
||||||
|
# 0.07 A man is riding a white horse on an enclosed ground.
|
||||||
|
# 0.01 The girl is carrying a baby.
|
||||||
|
# 0.01 Two men pushed carts through the woods.
|
||||||
|
# 0.01 A monkey is playing drums.
|
||||||
|
# 0.01 A woman is playing violin.
|
||||||
|
# 0.01 A cheetah is running behind its prey.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# # 3. Alternatively, you can also manually compute the score between two sentences
|
||||||
|
# import numpy as np
|
||||||
|
|
||||||
|
# sentence_combinations = [[query, sentence] for sentence in corpus]
|
||||||
|
# scores = model.predict(sentence_combinations)
|
||||||
|
|
||||||
|
# # Sort the scores in decreasing order to get the corpus indices
|
||||||
|
# ranked_indices = np.argsort(scores)[::-1]
|
||||||
|
# print("Scores:", scores)
|
||||||
|
# print("Indices:", ranked_indices)
|
||||||
|
# """
|
||||||
|
# Scores: [0.6732372, 0.34102544, 0.00542465, 0.07569341, 0.00525378, 0.00536814, 0.06676237, 0.00534825, 0.00516717]
|
||||||
|
# Indices: [0 1 3 6 2 5 7 4 8]
|
||||||
|
# """
|
71
generate_parallel_avsd.sh
Executable file
|
@@ -0,0 +1,71 @@
|
||||||
|
# export MODEL=$1
|
||||||
|
# export TAG=$2
|
||||||
|
# export MODE=$3
|
||||||
|
# export EVAL_DIR=$4
|
||||||
|
# export MEDIUM=$5
|
||||||
|
# export DSTC=$6
|
||||||
|
|
||||||
|
export MODEL='v2dial/stage_3'
|
||||||
|
export TAG='finetuned_no_experts_avsd'
|
||||||
|
export MODE='generate'
|
||||||
|
export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetune_without_experts_avsd'
|
||||||
|
export DSTC=7
|
||||||
|
|
||||||
|
# >>> conda initialize >>>
|
||||||
|
# !! Contents within this block are managed by 'conda init' !!
|
||||||
|
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
eval "$__conda_setup"
|
||||||
|
else
|
||||||
|
if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||||
|
. "/opt/anaconda3/etc/profile.d/conda.sh"
|
||||||
|
else
|
||||||
|
export PATH="/opt/anaconda3/bin:$PATH"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
unset __conda_setup
|
||||||
|
# <<< conda initialize <<<
|
||||||
|
|
||||||
|
conda activate v2dial
|
||||||
|
|
||||||
|
if [ $DSTC -eq 10 ]; then
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0112 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0112 --end_idx_gen 0224 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 0224 --end_idx_gen 0336 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 0336 --end_idx_gen 0448 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 0448 --end_idx_gen 0560 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 0560 --end_idx_gen 0672 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 0672 --end_idx_gen 0784 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 0784 --end_idx_gen 0896 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0896 --end_idx_gen 1008 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1008 --end_idx_gen 1120 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 1120 --end_idx_gen 1232 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 1232 --end_idx_gen 1344 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 1344 --end_idx_gen 1456 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 1456 --end_idx_gen 1568 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 1568 --end_idx_gen 1680 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 1680 --end_idx_gen 1804 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
else
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0107 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0107 --end_idx_gen 0214 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 0214 --end_idx_gen 0321 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 0321 --end_idx_gen 0428 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 0428 --end_idx_gen 0535 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 0535 --end_idx_gen 0642 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 0642 --end_idx_gen 0749 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 0749 --end_idx_gen 0856 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0856 --end_idx_gen 0963 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0963 --end_idx_gen 1070 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 1070 --end_idx_gen 1177 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 1177 --end_idx_gen 1284 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 1284 --end_idx_gen 1391 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 1391 --end_idx_gen 1498 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 1498 --end_idx_gen 1605 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 1605 --end_idx_gen 1710 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 00 --end_idx_gen 10 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
# export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 10 --end_idx_gen 20 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
|
||||||
|
wait
|
||||||
|
python merge_pred_avsd.py --dstc $DSTC
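The hard-coded index ranges above split the AVSD test set (upper bound 1710 for DSTC7/8, 1804 for DSTC10 in this script) into 16 contiguous chunks, two per GPU. A minimal sketch of how such boundaries can be generated; it reproduces the DSTC7/8 branch exactly, while the DSTC10 branch above uses a hand-tuned width of 112 with a larger final chunk:

import math

def chunk_bounds(num_items, num_chunks=16):
    # contiguous [start, end) ranges covering 0..num_items
    size = math.ceil(num_items / num_chunks)
    return [(i * size, min((i + 1) * size, num_items)) for i in range(num_chunks)]

print(chunk_bounds(1710))  # (0, 107), (107, 214), ..., (1605, 1710), matching the else-branch above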
|
51
generate_parallel_nextqa.sh
Executable file
|
@@ -0,0 +1,51 @@
|
||||||
|
# export MODEL=$1
|
||||||
|
# export TAG=$2
|
||||||
|
# export MODE=$3
|
||||||
|
# export EVAL_DIR=$4
|
||||||
|
|
||||||
|
export MODEL='v2dial/stage_3'
|
||||||
|
export TAG='nextqa_with_test_captions'
|
||||||
|
export MODE='generate'
|
||||||
|
export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-from_stage1_only_nextqa_after_avsd_4_frames_3_rounds_ft_fp16'
|
||||||
|
|
||||||
|
# >>> conda initialize >>>
|
||||||
|
# !! Contents within this block are managed by 'conda init' !!
|
||||||
|
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
eval "$__conda_setup"
|
||||||
|
else
|
||||||
|
if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||||
|
. "/opt/anaconda3/etc/profile.d/conda.sh"
|
||||||
|
else
|
||||||
|
export PATH="/opt/anaconda3/bin:$PATH"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
unset __conda_setup
|
||||||
|
# <<< conda initialize <<<
|
||||||
|
|
||||||
|
conda activate v2dial
|
||||||
|
|
||||||
|
# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 10 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 10 --end_idx_gen 20 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0573 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0573 --end_idx_gen 1146 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1146 --end_idx_gen 1719 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1719 --end_idx_gen 2292 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 2292 --end_idx_gen 2865 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 2865 --end_idx_gen 3438 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 3438 --end_idx_gen 4011 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 4011 --end_idx_gen 4584 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG &
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 4584 --end_idx_gen 5157 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 5157 --end_idx_gen 5730 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 5730 --end_idx_gen 6303 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 6303 --end_idx_gen 6876 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 6876 --end_idx_gen 7449 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 7449 --end_idx_gen 8022 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 8022 --end_idx_gen 8495 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 8495 --end_idx_gen 9178 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
|
||||||
|
wait
|
||||||
|
|
||||||
|
python merge_pred_nextqa.py
|
67
generate_parallel_visdial.sh
Executable file
|
@@ -0,0 +1,67 @@
|
||||||
|
# export MODEL=$1
|
||||||
|
# export TAG=$2
|
||||||
|
# export MODE=$3
|
||||||
|
# export EVAL_DIR=$4
|
||||||
|
# export MEDIUM=$5
|
||||||
|
# export DSTC=$6
|
||||||
|
|
||||||
|
export MODEL='v2dial/stage_3'
|
||||||
|
export TAG='finetuned_visdial_from_scratch'
|
||||||
|
export MODE='generate'
|
||||||
|
export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetuned_visdial_from_scratch/'
|
||||||
|
|
||||||
|
# >>> conda initialize >>>
|
||||||
|
# !! Contents within this block are managed by 'conda init' !!
|
||||||
|
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
eval "$__conda_setup"
|
||||||
|
else
|
||||||
|
if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||||
|
. "/opt/anaconda3/etc/profile.d/conda.sh"
|
||||||
|
else
|
||||||
|
export PATH="/opt/anaconda3/bin:$PATH"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
unset __conda_setup
|
||||||
|
# <<< conda initialize <<<
|
||||||
|
|
||||||
|
conda activate v2dial
|
||||||
|
# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 00000 --end_idx_gen 10 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
# export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 00010 --end_idx_gen 20 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 00000 --end_idx_gen 00645 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 00645 --end_idx_gen 01290 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 01290 --end_idx_gen 01935 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 01935 --end_idx_gen 02580 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 02580 --end_idx_gen 03225 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 03225 --end_idx_gen 03870 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 03870 --end_idx_gen 04515 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 04515 --end_idx_gen 05160 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 05160 --end_idx_gen 05805 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 05805 --end_idx_gen 06450 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 06450 --end_idx_gen 07095 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 07095 --end_idx_gen 07740 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 07740 --end_idx_gen 08385 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 08385 --end_idx_gen 09030 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 09030 --end_idx_gen 09675 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 09675 --end_idx_gen 10320 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 10320 --end_idx_gen 10965 --gen_subset_num 17 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 10965 --end_idx_gen 11610 --gen_subset_num 18 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 11610 --end_idx_gen 12255 --gen_subset_num 19 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 12255 --end_idx_gen 12900 --gen_subset_num 20 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 12900 --end_idx_gen 13545 --gen_subset_num 21 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 13545 --end_idx_gen 14190 --gen_subset_num 22 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 14190 --end_idx_gen 14835 --gen_subset_num 23 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 14835 --end_idx_gen 15480 --gen_subset_num 24 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 15480 --end_idx_gen 16125 --gen_subset_num 25 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 16125 --end_idx_gen 16770 --gen_subset_num 26 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 16770 --end_idx_gen 17415 --gen_subset_num 27 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 17415 --end_idx_gen 18060 --gen_subset_num 28 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 18060 --end_idx_gen 18705 --gen_subset_num 29 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 18705 --end_idx_gen 19350 --gen_subset_num 30 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 19350 --end_idx_gen 19995 --gen_subset_num 31 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 19995 --end_idx_gen 20640 --gen_subset_num 32 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
|
||||||
|
|
||||||
|
wait
|
||||||
|
python eval_visdial.py
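The VisDial branch above covers indices 0 to 20640 in 32 chunks of 645 items, four per GPU; this presumably corresponds to the 2064 VisDial v1.0 validation dialogs with 10 rounds each:

print(2064 * 10)    # 20640 generation items in total (assuming 10 rounds per dialog)
print(20640 // 32)  # 645 items per chunk, matching the boundaries above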
|
177
main_stage_1.py
Normal file
|
@@ -0,0 +1,177 @@
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
# from transformers import BartTokenizer
|
||||||
|
from torch.utils.data import ConcatDataset
|
||||||
|
|
||||||
|
from utils.init import initialize_from_env
|
||||||
|
# from datasets.pretraining import load_datasets, VideoTextRetDataset
|
||||||
|
# from datasets.utils import get_datasets_media
|
||||||
|
from models.setup import setup_model, setup_data, setup_data_test
|
||||||
|
from tasks.pre_train import pre_train
|
||||||
|
# from tasks.ft_avsd import ft_avsd, generate
|
||||||
|
# from tasks.stage_2_3 import pretrain
|
||||||
|
# from tasks.stage_2 import train as train_stage_2
|
||||||
|
|
||||||
|
# torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Main script for v2dial')
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
type=str,
|
||||||
|
default='v2dial/stage_1',
|
||||||
|
help='model name to train or test')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
type=str,
|
||||||
|
default='train',
|
||||||
|
help='train, generate or debug'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--eval_dir',
|
||||||
|
type=str,
|
||||||
|
default='/scratch/abdessaied/projects/V2Dial_TU/logs/stage_4/v2dial-flant5_large_bert_experts_4_only_gen_AVSD'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_mode',
|
||||||
|
type=str,
|
||||||
|
default='online',
|
||||||
|
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_project',
|
||||||
|
type=str,
|
||||||
|
default='V2Dial'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--tag',
|
||||||
|
type=str,
|
||||||
|
# default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD',
|
||||||
|
# default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed',
|
||||||
|
# default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16',
|
||||||
|
default='without_sep_spatial_temporal_experts',
|
||||||
|
# default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs',
|
||||||
|
help="Tag to differentiate the models"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--medium',
|
||||||
|
type=str,
|
||||||
|
default='avsd',
|
||||||
|
help="Medium of the test dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--start_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The start index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--end_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="The end index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--gen_subset_num',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The index of the test split for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument('--ssh', action='store_true',
|
||||||
|
help='whether or not the command is being executed via ssh. '
'If set to True, nothing is logged to the screen; messages are only redirected to the log file')
|
||||||
|
|
||||||
|
|
||||||
|
def main(gpu, config, args):
|
||||||
|
|
||||||
|
config['gpu'] = gpu
|
||||||
|
if config['distributed']:
|
||||||
|
dist.init_process_group(
|
||||||
|
backend='nccl',
|
||||||
|
world_size=config['num_gpus'],
|
||||||
|
rank=gpu
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(gpu)
|
||||||
|
|
||||||
|
device = torch.device(f'cuda:{gpu}')
|
||||||
|
if config.use_cpu:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
config['device'] = device
|
||||||
|
# model = V2Dial(config)
|
||||||
|
|
||||||
|
# config['num_training_steps'] = num_step_per_epoch * config['epochs']
|
||||||
|
# config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs']
|
||||||
|
if config['training']:
|
||||||
|
train_dataloaders, val_dataloaders = setup_data(config)
|
||||||
|
|
||||||
|
(
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
global_step,
|
||||||
|
webvid_step,
|
||||||
|
cc3m_step,
|
||||||
|
config
|
||||||
|
) = setup_model(config, pretrain=True)
|
||||||
|
|
||||||
|
pre_train(
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
train_dataloaders,
|
||||||
|
val_dataloaders,
|
||||||
|
optimizer,
|
||||||
|
global_step,
|
||||||
|
webvid_step,
|
||||||
|
cc3m_step,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
|
||||||
|
if config['distributed']:
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
model, stage = args.model.split('/')
|
||||||
|
config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag)
|
||||||
|
config['wandb_enabled'] = args.wandb_mode == 'online'
|
||||||
|
config['training'] = args.mode == 'train'
|
||||||
|
config['generating'] = args.mode == 'generate'
|
||||||
|
config['debugging'] = args.mode == 'debug'
|
||||||
|
|
||||||
|
config['wandb_mode'] = args.wandb_mode
|
||||||
|
config['medium'] = args.medium
|
||||||
|
config['start_idx_gen'] = args.start_idx_gen
|
||||||
|
config['end_idx_gen'] = args.end_idx_gen
|
||||||
|
|
||||||
|
# config['wandb_project']
|
||||||
|
# if config['accelerator'] == 'ddp':
|
||||||
|
if config['num_gpus'] > 1:
|
||||||
|
config['distributed'] = True
|
||||||
|
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
|
||||||
|
else:
|
||||||
|
config['distributed'] = False
|
||||||
|
main(0, config, args)
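For reference, a hypothetical single-node launch of this stage (flag names mirror the argparse definitions above; the tag value is made up):

# python main_stage_1.py --model v2dial/stage_1 --mode train \
#     --wandb_mode disabled --tag my_stage1_run
# With num_gpus > 1 in the resolved config, main() is spawned once per GPU via
# torch.multiprocessing.spawn; otherwise it runs once in the current process.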
|
186
main_stage_2.py
Normal file
|
@@ -0,0 +1,186 @@
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
# from transformers import BartTokenizer
|
||||||
|
from torch.utils.data import ConcatDataset
|
||||||
|
|
||||||
|
from utils.init import initialize_from_env
|
||||||
|
# from datasets.pretraining import load_datasets, VideoTextRetDataset
|
||||||
|
# from datasets.utils import get_datasets_media
|
||||||
|
from models.setup import setup_model, setup_data, setup_data_test
|
||||||
|
# from tasks.ft_avsd import ft_avsd, generate
|
||||||
|
from tasks.stage_2 import train as train_stage_2
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Main script for v2dial')
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
type=str,
|
||||||
|
default='v2dial/stage_2',
|
||||||
|
help='model name to train or test')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
type=str,
|
||||||
|
default='train',
|
||||||
|
help='train, generate or debug'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--eval_dir',
|
||||||
|
type=str,
|
||||||
|
default='/scratch/abdessaied/projects/V2Dial_TU/logs/stage_4/v2dial-flant5_large_bert_experts_4_only_gen_AVSD'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_mode',
|
||||||
|
type=str,
|
||||||
|
default='online',
|
||||||
|
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_project',
|
||||||
|
type=str,
|
||||||
|
default='V2Dial'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--tag',
|
||||||
|
type=str,
|
||||||
|
# default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD',
|
||||||
|
# default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed',
|
||||||
|
# default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16',
|
||||||
|
default='from_stage_1_only_gen_loss_frozen_llm',
|
||||||
|
# default='blub',
|
||||||
|
# default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs',
|
||||||
|
help="Tag to differentiate the models"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--medium',
|
||||||
|
type=str,
|
||||||
|
default='avsd',
|
||||||
|
help="Medium of the test dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--start_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The start index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--end_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="The end index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--gen_subset_num',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The index of the test split for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument('--ssh', action='store_true',
|
||||||
|
help='whether or not the command is being executed via ssh. '
'If set to True, nothing is logged to the screen; messages are only redirected to the log file')
|
||||||
|
|
||||||
|
|
||||||
|
def main(gpu, config, args):
|
||||||
|
|
||||||
|
config['gpu'] = gpu
|
||||||
|
if config['distributed']:
|
||||||
|
dist.init_process_group(
|
||||||
|
backend='nccl',
|
||||||
|
world_size=config['num_gpus'],
|
||||||
|
rank=gpu
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(gpu)
|
||||||
|
|
||||||
|
device = torch.device(f'cuda:{gpu}')
|
||||||
|
if config.use_cpu:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
config['device'] = device
|
||||||
|
# model = V2Dial(config)
|
||||||
|
|
||||||
|
# config['num_training_steps'] = num_step_per_epoch * config['epochs']
|
||||||
|
# config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs']
|
||||||
|
if config['training']:
|
||||||
|
train_dataloaders, val_dataloaders = setup_data(config)
|
||||||
|
|
||||||
|
(
|
||||||
|
model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, config
|
||||||
|
) = setup_model(config)
|
||||||
|
|
||||||
|
if config['training']:
|
||||||
|
train_stage_2(
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
train_dataloaders,
|
||||||
|
val_dataloaders,
|
||||||
|
optimizer,
|
||||||
|
global_step,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
|
||||||
|
# if config['stage'] == 'stage_3':
|
||||||
|
# (
|
||||||
|
# model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, config
|
||||||
|
# ) = setup_model(config)
|
||||||
|
# if config['training']:
|
||||||
|
# ft_avsd(
|
||||||
|
# model,
|
||||||
|
# model_without_ddp,
|
||||||
|
# train_dataloaders,
|
||||||
|
# val_dataloaders,
|
||||||
|
# optimizer,
|
||||||
|
# global_step,
|
||||||
|
# scheduler,
|
||||||
|
# scaler,
|
||||||
|
# start_epoch,
|
||||||
|
# config
|
||||||
|
# )
|
||||||
|
# elif config['generating']:
|
||||||
|
# test_dataloader = setup_data_test(config, args)
|
||||||
|
# generate(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
|
||||||
|
|
||||||
|
if config['distributed']:
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
model, stage = args.model.split('/')
|
||||||
|
config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag)
|
||||||
|
config['wandb_enabled'] = args.wandb_mode == 'online'
|
||||||
|
config['training'] = args.mode == 'train'
|
||||||
|
config['generating'] = args.mode == 'generate'
|
||||||
|
config['debugging'] = args.mode == 'debug'
|
||||||
|
|
||||||
|
config['wandb_mode'] = args.wandb_mode
|
||||||
|
config['medium'] = args.medium
|
||||||
|
config['start_idx_gen'] = args.start_idx_gen
|
||||||
|
config['end_idx_gen'] = args.end_idx_gen
|
||||||
|
|
||||||
|
# config['wandb_project']
|
||||||
|
# if config['accelerator'] == 'ddp':
|
||||||
|
if config['num_gpus'] > 1:
|
||||||
|
config['distributed'] = True
|
||||||
|
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
|
||||||
|
else:
|
||||||
|
config['distributed'] = False
|
||||||
|
main(0, config, args)
|
185
main_stage_3.py
Normal file
|
@@ -0,0 +1,185 @@
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
# from transformers import BartTokenizer
|
||||||
|
from torch.utils.data import ConcatDataset
|
||||||
|
|
||||||
|
from utils.init import initialize_from_env
|
||||||
|
# from datasets.pretraining import load_datasets, VideoTextRetDataset
|
||||||
|
# from datasets.utils import get_datasets_media
|
||||||
|
from models.setup import setup_model, setup_data, setup_data_test
|
||||||
|
# from tasks.ft_avsd import ft_avsd, generate
|
||||||
|
from tasks.stage_3 import ft_avsd, generate, generate_nextqa, generate_visdial
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Main script for v2dial')
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
type=str,
|
||||||
|
default='v2dial/stage_3',
|
||||||
|
help='model name to train or test')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
type=str,
|
||||||
|
default='generate',
|
||||||
|
help='train, generate or debug'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--eval_dir',
|
||||||
|
type=str,
|
||||||
|
default='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetune_without_stc_stm_only_visdial'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_mode',
|
||||||
|
type=str,
|
||||||
|
default='online',
|
||||||
|
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--wandb_project',
|
||||||
|
type=str,
|
||||||
|
default='V2Dial'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--tag',
|
||||||
|
type=str,
|
||||||
|
default="finetuned_visdial_without_stm_stc",
|
||||||
|
# default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD',
|
||||||
|
# default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed',
|
||||||
|
# default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16',
|
||||||
|
# default='from_stage1_after_avsd_only_visdial_4_frames_10_rounds_ft',
|
||||||
|
# default='from_scratch_visdial',
|
||||||
|
# default='no_moes_div_st_from_scratch_only_avsd_4_frames_3_rounds_ft_fp16',
|
||||||
|
# default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs',
|
||||||
|
help="Tag to differentiate the models"
|
||||||
|
)
|
||||||
|
|
||||||
|
# parser.add_argument(
|
||||||
|
# '--medium',
|
||||||
|
# type=str,
|
||||||
|
# default='avsd',
|
||||||
|
# help="Medium of the test dataset"
|
||||||
|
# )
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--start_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The start index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--end_idx_gen',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="The end index for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--gen_subset_num',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The index of the test split for generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument('--ssh', action='store_true',
|
||||||
|
help='whether or not the command is being executed via ssh. '
'If set to True, nothing is logged to the screen; messages are only redirected to the log file')
|
||||||
|
|
||||||
|
|
||||||
|
def main(gpu, config, args):
|
||||||
|
|
||||||
|
config['gpu'] = gpu
|
||||||
|
if config['distributed']:
|
||||||
|
dist.init_process_group(
|
||||||
|
backend='nccl',
|
||||||
|
world_size=config['num_gpus'],
|
||||||
|
rank=gpu
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(gpu)
|
||||||
|
|
||||||
|
device = torch.device(f'cuda:{gpu}')
|
||||||
|
if config.use_cpu:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
config['device'] = device
|
||||||
|
# model = V2Dial(config)
|
||||||
|
|
||||||
|
# config['num_training_steps'] = num_step_per_epoch * config['epochs']
|
||||||
|
# config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs']
|
||||||
|
if config['training']:
|
||||||
|
train_dataloaders, val_dataloaders = setup_data(config)
|
||||||
|
|
||||||
|
(
|
||||||
|
model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, visdial_step, avsd_step, nextqa_step, config
|
||||||
|
) = setup_model(config)
|
||||||
|
|
||||||
|
if config['training']:
|
||||||
|
ft_avsd(
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
train_dataloaders,
|
||||||
|
val_dataloaders,
|
||||||
|
optimizer,
|
||||||
|
global_step,
|
||||||
|
visdial_step,
|
||||||
|
avsd_step,
|
||||||
|
nextqa_step,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
elif config['generating']:
|
||||||
|
test_dataloader = setup_data_test(config)
|
||||||
|
if config.media_test == 'avsd':
|
||||||
|
generate(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
|
||||||
|
if config.media_test == 'visdial':
|
||||||
|
generate_visdial(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
|
||||||
|
elif config.media_test == 'nextqa':
|
||||||
|
generate_nextqa(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
|
||||||
|
|
||||||
|
if config['distributed']:
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
model, stage = args.model.split('/')
|
||||||
|
config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag)
|
||||||
|
config['wandb_enabled'] = args.wandb_mode == 'online'
|
||||||
|
config['training'] = args.mode == 'train'
|
||||||
|
config['generating'] = args.mode == 'generate'
|
||||||
|
config['debugging'] = args.mode == 'debug'
|
||||||
|
|
||||||
|
config['wandb_mode'] = args.wandb_mode
|
||||||
|
# config['medium'] = args.medium
|
||||||
|
config['start_idx_gen'] = args.start_idx_gen
|
||||||
|
config['end_idx_gen'] = args.end_idx_gen
|
||||||
|
config['expert_permutation'] = None
|
||||||
|
# config['expert_permutation'] = {
|
||||||
|
# 'spatial': 'history',
|
||||||
|
# 'temporal': 'temporal',
|
||||||
|
# 'caption': 'caption',
|
||||||
|
# 'history': 'spatial'
|
||||||
|
# }
|
||||||
|
|
||||||
|
# config['wandb_project']
|
||||||
|
# if config['accelerator'] == 'ddp':
|
||||||
|
if config['num_gpus'] > 1:
|
||||||
|
config['distributed'] = True
|
||||||
|
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
|
||||||
|
else:
|
||||||
|
config['distributed'] = False
|
||||||
|
main(0, config, args)
|
61
merge_pred_avsd.py
Normal file
|
@@ -0,0 +1,61 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Merge parallel AVSD prediction files (V2Dial)')
|
||||||
|
parser.add_argument(
|
||||||
|
'--dstc',
|
||||||
|
type=int,
|
||||||
|
default=7,
|
||||||
|
choices=[7, 8, 10],
|
||||||
|
help='DSTC challenge identifier')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert args.dstc in [7, 8, 10]
|
||||||
|
if args.dstc == 7:
|
||||||
|
output_dir = 'output/dstc7'
|
||||||
|
raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC7-AVSD.json'
|
||||||
|
|
||||||
|
elif args.dstc == 8:
|
||||||
|
output_dir = 'output/dstc8'
|
||||||
|
raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC8-AVSD.json'
|
||||||
|
else:
|
||||||
|
output_dir = 'output/dstc10'
|
||||||
|
raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC10-AVSD.json'
|
||||||
|
|
||||||
|
with open(raw_data_path, 'r') as f:
|
||||||
|
raw_dialogs = json.load(f)['dialogs']
|
||||||
|
|
||||||
|
file_paths = os.listdir(output_dir)
|
||||||
|
file_paths = list(filter(lambda f: 'part' in f , file_paths))
|
||||||
|
name = file_paths[0]
|
||||||
|
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
|
||||||
|
|
||||||
|
dialogs = {}
|
||||||
|
for pth in file_paths:
|
||||||
|
with open(pth, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
for dialog in data['dialogs']:
|
||||||
|
vid_id = dialog['image_id']
|
||||||
|
dialogs[vid_id] = dialog
|
||||||
|
# dialogs.extend(data['dialogs'])
|
||||||
|
os.remove(pth)
|
||||||
|
|
||||||
|
# Now, re-establish the original order of the dialogs
|
||||||
|
res = []
|
||||||
|
for dialog in raw_dialogs:
|
||||||
|
vid_id = dialog['image_id']
|
||||||
|
res.append(dialogs[vid_id])
|
||||||
|
|
||||||
|
res = {
|
||||||
|
'dialogs': res
|
||||||
|
}
|
||||||
|
|
||||||
|
name = "".join(name.split('-')[:-1]) + '.json'
|
||||||
|
output_path = os.path.join(output_dir, name)
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
json.dump(res, f, indent=4)
|
||||||
|
|
||||||
|
print('[INFO] Files merged and saved in {}'.format(output_path))
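A minimal sketch of the renaming above, with a hypothetical part-file name: everything after the last '-' (the per-subset suffix) is dropped and '.json' is re-appended; note that any other '-' left in the base name would also disappear, since the remaining pieces are re-joined with an empty string:

name = 'results_dstc7_beam_5-part_03.json'        # hypothetical part-file name
merged = ''.join(name.split('-')[:-1]) + '.json'
print(merged)                                     # -> results_dstc7_beam_5.json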
|
34
merge_pred_nextqa.py
Normal file
|
@@ -0,0 +1,34 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Merge parallel NExT-QA prediction files (V2Dial)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
output_dir = 'output/nextqa'
|
||||||
|
|
||||||
|
file_paths = os.listdir(output_dir)
|
||||||
|
file_paths = list(filter(lambda f: 'part' in f , file_paths))
|
||||||
|
name = file_paths[0]
|
||||||
|
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for pth in file_paths:
|
||||||
|
with open(pth, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for video_id in data:
|
||||||
|
if video_id not in results:
|
||||||
|
results[video_id] = data[video_id]
|
||||||
|
else:
|
||||||
|
for qid in data[video_id]:
|
||||||
|
if qid not in results[video_id]:
|
||||||
|
results[video_id][qid] = data[video_id][qid]
|
||||||
|
os.remove(pth)
|
||||||
|
|
||||||
|
name = "".join(name.split('-')[:-1]) + '.json'
|
||||||
|
output_path = os.path.join(output_dir, name)
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
|
||||||
|
print('[INFO] Files merged and saved in {}'.format(output_path))
|
0
models/__init__.py
Normal file
1216
models/backbones/Qformer.py
Executable file
File diff suppressed because it is too large
0
models/backbones/__init__.py
Normal file
247
models/backbones/base_model.py
Executable file
|
@@ -0,0 +1,247 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
|
||||||
|
from models.common.utils import get_abs_path, is_url
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(nn.Module):
|
||||||
|
"""Base class for models."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return list(self.parameters())[0].device
|
||||||
|
|
||||||
|
def load_checkpoint(self, url_or_filename):
|
||||||
|
"""
|
||||||
|
Load from a finetuned checkpoint.
|
||||||
|
|
||||||
|
This should expect no mismatch in the model keys and the checkpoint keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if is_url(url_or_filename):
|
||||||
|
cached_file = download_cached_file(
|
||||||
|
url_or_filename, check_hash=False, progress=True
|
||||||
|
)
|
||||||
|
checkpoint = torch.load(cached_file, map_location="cpu")
|
||||||
|
elif os.path.isfile(url_or_filename):
|
||||||
|
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("checkpoint url or path is invalid")
|
||||||
|
|
||||||
|
if "model" in checkpoint.keys():
|
||||||
|
state_dict = checkpoint["model"]
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
|
||||||
|
msg = self.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
logging.info("Missing keys {}".format(msg.missing_keys))
|
||||||
|
logging.info("load checkpoint from %s" % url_or_filename)
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_type):
|
||||||
|
"""
|
||||||
|
Build a pretrained model from default configuration file, specified by model_type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- model_type (str): model type, specifying architecture and checkpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- model (nn.Module): pretrained or finetuned model, depending on the configuration.
|
||||||
|
"""
|
||||||
|
model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
|
||||||
|
model = cls.from_config(model_cfg)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_config_path(cls, model_type):
|
||||||
|
assert (
|
||||||
|
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
|
||||||
|
), "Unknown model type {}".format(model_type)
|
||||||
|
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
|
||||||
|
|
||||||
|
def load_checkpoint_from_config(self, cfg, **kwargs):
|
||||||
|
"""
|
||||||
|
Load checkpoint as specified in the config file.
|
||||||
|
|
||||||
|
If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
|
||||||
|
When loading the pretrained model, each task-specific architecture may define their
|
||||||
|
own load_from_pretrained() method.
|
||||||
|
"""
|
||||||
|
load_finetuned = cfg.get("load_finetuned", True)
|
||||||
|
if load_finetuned:
|
||||||
|
finetune_path = cfg.get("finetuned", None)
|
||||||
|
assert (
|
||||||
|
finetune_path is not None
|
||||||
|
), "Found load_finetuned is True, but finetune_path is None."
|
||||||
|
self.load_checkpoint(url_or_filename=finetune_path)
|
||||||
|
else:
|
||||||
|
# load pre-trained weights
|
||||||
|
pretrain_path = cfg.get("pretrained", None)
|
||||||
|
assert "Found load_finetuned is False, but pretrain_path is None."
|
||||||
|
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
|
||||||
|
|
||||||
|
def before_evaluation(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def show_n_params(self, return_str=True):
|
||||||
|
tot = 0
|
||||||
|
for p in self.parameters():
|
||||||
|
w = 1
|
||||||
|
for x in p.shape:
|
||||||
|
w *= x
|
||||||
|
tot += w
|
||||||
|
if return_str:
|
||||||
|
if tot >= 1e6:
|
||||||
|
return "{:.1f}M".format(tot / 1e6)
|
||||||
|
else:
|
||||||
|
return "{:.1f}K".format(tot / 1e3)
|
||||||
|
else:
|
||||||
|
return tot
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for primitive encoders, such as ViT, TimeSformer, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward_features(self, samples, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return list(self.parameters())[0].device
|
||||||
|
|
||||||
|
|
||||||
|
class SharedQueueMixin:
|
||||||
|
@torch.no_grad()
|
||||||
|
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
|
||||||
|
# gather keys before updating queue
|
||||||
|
image_feats = concat_all_gather(image_feat)
|
||||||
|
text_feats = concat_all_gather(text_feat)
|
||||||
|
|
||||||
|
batch_size = image_feats.shape[0]
|
||||||
|
|
||||||
|
ptr = int(self.queue_ptr)
|
||||||
|
assert self.queue_size % batch_size == 0 # for simplicity
|
||||||
|
|
||||||
|
# replace the keys at ptr (dequeue and enqueue)
|
||||||
|
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
|
||||||
|
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
|
||||||
|
|
||||||
|
if idxs is not None:
|
||||||
|
idxs = concat_all_gather(idxs)
|
||||||
|
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
|
||||||
|
|
||||||
|
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
||||||
|
self.queue_ptr[0] = ptr
|
||||||
|
|
||||||
|
|
||||||
|
class MomentumDistilationMixin:
|
||||||
|
@torch.no_grad()
|
||||||
|
def copy_params(self):
|
||||||
|
for model_pair in self.model_pairs:
|
||||||
|
for param, param_m in zip(
|
||||||
|
model_pair[0].parameters(), model_pair[1].parameters()
|
||||||
|
):
|
||||||
|
param_m.data.copy_(param.data) # initialize
|
||||||
|
param_m.requires_grad = False # not update by gradient
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _momentum_update(self):
|
||||||
|
for model_pair in self.model_pairs:
|
||||||
|
for param, param_m in zip(
|
||||||
|
model_pair[0].parameters(), model_pair[1].parameters()
|
||||||
|
):
|
||||||
|
param_m.data = param_m.data * self.momentum + param.data * (
|
||||||
|
1.0 - self.momentum
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GatherLayer(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Gather tensors from all workers with support for backward propagation:
|
||||||
|
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
output = [
|
||||||
|
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(output, x)
|
||||||
|
return tuple(output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, *grads):
|
||||||
|
all_gradients = torch.stack(grads)
|
||||||
|
torch.distributed.all_reduce(all_gradients)
|
||||||
|
return all_gradients[torch.distributed.get_rank()]
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_with_grad(tensors):
|
||||||
|
"""
|
||||||
|
Performs all_gather operation on the provided tensors.
|
||||||
|
Graph remains connected for backward grad computation.
|
||||||
|
"""
|
||||||
|
# Queue the gathered tensors
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
# There is no need for reduction in the single-proc case
|
||||||
|
if world_size == 1:
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
# tensor_all = GatherLayer.apply(tensors)
|
||||||
|
tensor_all = GatherLayer.apply(tensors)
|
||||||
|
|
||||||
|
return torch.cat(tensor_all, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def concat_all_gather(tensor):
|
||||||
|
"""
|
||||||
|
Performs all_gather operation on the provided tensors.
|
||||||
|
*** Warning ***: torch.distributed.all_gather has no gradient.
|
||||||
|
"""
|
||||||
|
# if use distributed training
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
tensors_gather = [
|
||||||
|
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
||||||
|
|
||||||
|
output = torch.cat(tensors_gather, dim=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def tile(x, dim, n_tile):
|
||||||
|
init_dim = x.size(dim)
|
||||||
|
repeat_idx = [1] * x.dim()
|
||||||
|
repeat_idx[dim] = n_tile
|
||||||
|
x = x.repeat(*(repeat_idx))
|
||||||
|
order_index = torch.LongTensor(
|
||||||
|
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
|
||||||
|
)
|
||||||
|
return torch.index_select(x, dim, order_index.to(x.device))
|
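A minimal sketch (illustrative, not part of the committed file) of what the tile() helper above computes; it repeats entries of a tensor consecutively along one dimension, equivalent to torch.repeat_interleave:

    import torch
    from models.backbones.base_model import tile

    x = torch.tensor([[1, 2], [3, 4]])   # shape (2, 2)
    y = tile(x, dim=0, n_tile=2)         # each row repeated twice, in place
    # y -> tensor([[1, 2], [1, 2], [3, 4], [3, 4]]),
    # the same result as torch.repeat_interleave(x, 2, dim=0)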
0
models/backbones/beit/__init__.py
Normal file
107
models/backbones/beit/builder.py
Normal file
@ -0,0 +1,107 @@
import logging
import torch
from models.utils import (interpolate_pos_relative_bias_beit,
                          load_temp_embed_with_mismatch)

logger = logging.getLogger(__name__)


def interpolate_pos_embed_beit(state_dict, new_model):
    """interpolate the positional embeddings.
    The spatial pe is relative and temporal pe is absolute.
    additional temporal pe is padded with 0.

    Args:
        state_dict (dict): The state_dict.
        new_model (nn.Module): The created model.

    Returns: dict. The state_dict with updated positional embeddings.

    """
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=new_model.state_dict(),
        patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape,
    )
    # absolute temporal pos bias
    temporal_pe_key = "beit.embeddings.temporal_position_embeddings"
    if temporal_pe_key in state_dict:
        logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
        state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
            temp_embed_old=state_dict[temporal_pe_key],
            temp_embed_new=new_model.state_dict()[temporal_pe_key],
        )
    return state_dict


def extract_beit_from_vindlu(vindlu_state_dict):
    beit_state_dict = {}
    beit_param_names = [k for k in vindlu_state_dict if k.startswith('vision_encoder.') and 'temp_model' not in k]
    for param_name in beit_param_names:
        new_name = param_name.replace('vision_encoder.', '')
        beit_state_dict[new_name] = vindlu_state_dict[param_name]

    return beit_state_dict


def build_beit(model_config, image_res, checkpoint=False):
    """build beit with configuration.

    Args:
        model_config (dict): The configs for beit.
        image_res (int): The image resolution.
        checkpoint (bool): Whether to enable gradient checkpointing.

    Returns: nn.Module

    """
    from .st_beit import BeitConfig as config_cls
    from .st_beit import BeitModel as model_cls

    vindlu_state_dict = torch.load(model_config['vindlu_path'])['model']
    state_dict = extract_beit_from_vindlu(vindlu_state_dict)
    model_config = model_config['beit_config_json']

    logger.info(
        f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}."
    )
    # BEiT uses average pooled tokens instead of [CLS] used by other models
    aux_kwargs = {"add_pooling_layer": True}
    # tmp_model = model_cls.from_pretrained(model_config['beit_pretrained'], **aux_kwargs)

    # tmp_model = model_cls.from_pretrained(model_config['pretrained'], **aux_kwargs)
    # state_dict = tmp_model.state_dict()

    # del tmp_model

    logger.info(f"Init new model with new image size {image_res}, and load weights.")

    # other_cfg = model_config.temporal_modeling
    other_cfg = {}

    vit_config = config_cls.from_pretrained(
        model_config['pretrained'], image_size=image_res, **other_cfg
    )

    # vit_config.update(model_config)

    model = model_cls(config=vit_config, **aux_kwargs)

    if checkpoint:
        model.gradient_checkpointing_enable()

    # interpolate relative pos bias
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=model.state_dict(),
        patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
    )

    # del prompt_bias_table
    for k in list(state_dict.keys()):
        if "prompt_bias_table" in k:
            del state_dict[k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    return model
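A usage sketch for build_beit() above. Only the key names ('vindlu_path', 'beit_config_json', 'pretrained') come from the code; the checkpoint path and the HuggingFace model id are placeholder assumptions:

    from models.backbones.beit.builder import build_beit

    beit_cfg = {
        "vindlu_path": "checkpoints/vindlu_pretrained.pth",                      # assumed path
        "beit_config_json": {"pretrained": "microsoft/beit-base-patch16-224-pt22k"},  # assumed HF id
    }
    vision_encoder = build_beit(beit_cfg, image_res=224, checkpoint=False)

The VindLU checkpoint must exist on disk for this to run; build_beit() strips the 'vision_encoder.' prefix from it and loads the result into a freshly configured st_beit model.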
1752
models/backbones/beit/st_beit.py
Normal file
File diff suppressed because it is too large
0
models/backbones/bert/__init__.py
Normal file
71
models/backbones/bert/builder.py
Normal file
@ -0,0 +1,71 @@
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel


def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
    """build text encoder.

    Args:
        model_config (dict): model config.
        pretrain (bool): Whether to do pretrain or finetuning.
        checkpoint (bool): whether to do gradient_checkpointing.

    Returns: TODO

    """
    bert_size = model_config['expert_size']
    bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
    # bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint
    bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]
    if expert_type == 'modality':
        if modality_type == 'vis':
            bert_config.cross_attention_freq = 2
        else:
            bert_config.cross_attention_freq = -1
    else:
        bert_config.cross_attention_freq = 1

    if pretrain:
        text_encoder, loading_info = BertForMaskedLM.from_pretrained(
            f'bert-{bert_size}-uncased',
            config=bert_config,
            output_loading_info=True,
        )
    else:
        text_encoder, loading_info = BertModel.from_pretrained(
            f'bert-{bert_size}-uncased',
            config=bert_config,
            add_pooling_layer=True,
            output_loading_info=True,
        )

    return text_encoder


def build_bert_decoder(model_config, checkpoint):
    """build text decoder the same as the multimodal encoder.

    Args:
        model_config (dict): model config.
        checkpoint (bool): whether to do gradient_checkpointing.

    Returns: TODO

    """
    bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
    bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint

    bert_config.fusion_layer = 0
    bert_config.num_hidden_layers = (
        bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
    )

    text_decoder, loading_info = BertLMHeadModel.from_pretrained(
        model_config.text_encoder.pretrained,
        config=bert_config,
        output_loading_info=True,
    )

    return text_decoder
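A hedged sketch of how build_bert() above might be called. The key names ('expert_size', 'bert_config_base', 'num_layers_modality_expert') follow the code; the json path and layer count are assumptions for illustration only:

    from models.backbones.bert.builder import build_bert

    bert_cfg = {
        "expert_size": "base",
        "bert_config_base": "configs/config_bert_base.json",   # assumed path
        "num_layers_modality_expert": 2,                        # assumed depth
    }
    vis_expert = build_bert(bert_cfg, pretrain=False, checkpoint=False,
                            expert_type="modality", modality_type="vis")

With expert_type="modality" and modality_type="vis", cross-attention is inserted every second layer (cross_attention_freq = 2); any other modality disables it (-1).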
546
models/backbones/bert/tokenization_bert.py
Normal file
@ -0,0 +1,546 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""


import collections
import os
import unicodedata
from typing import List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
        "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
        "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
        "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
        "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
        "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
        "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
        "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
        "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-uncased": 512,
    "bert-large-uncased": 512,
    "bert-base-cased": 512,
    "bert-large-cased": 512,
    "bert-base-multilingual-uncased": 512,
    "bert-base-multilingual-cased": 512,
    "bert-base-chinese": 512,
    "bert-base-german-cased": 512,
    "bert-large-uncased-whole-word-masking": 512,
    "bert-large-cased-whole-word-masking": 512,
    "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
    "bert-large-cased-whole-word-masking-finetuned-squad": 512,
    "bert-base-cased-finetuned-mrpc": 512,
    "bert-base-german-dbmdz-cased": 512,
    "bert-base-german-dbmdz-uncased": 512,
    "TurkuNLP/bert-base-finnish-cased-v1": 512,
    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
    "wietsedv/bert-base-dutch-cased": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-uncased": {"do_lower_case": True},
    "bert-large-uncased": {"do_lower_case": True},
    "bert-base-cased": {"do_lower_case": False},
    "bert-large-cased": {"do_lower_case": False},
    "bert-base-multilingual-uncased": {"do_lower_case": True},
    "bert-base-multilingual-cased": {"do_lower_case": False},
    "bert-base-chinese": {"do_lower_case": False},
    "bert-base-german-cased": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
    "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):

                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: ``[CLS] X ``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") +
                VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix +
                          "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.
        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(
            set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.
        For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.
        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
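A short, hedged sketch of using this local BertTokenizer. It assumes a WordPiece vocab file on disk (for example the vocab.txt shipped with bert-base-uncased); the exact subword split depends on that vocabulary:

    from models.backbones.bert.tokenization_bert import BertTokenizer

    tokenizer = BertTokenizer("vocab.txt", do_lower_case=True)      # assumed local vocab path
    tokens = tokenizer.tokenize("unaffable answers")                # e.g. ['un', '##aff', '##able', 'answers']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    with_special = tokenizer.build_inputs_with_special_tokens(ids)
    # Note: unlike upstream BERT, a single sequence here gets a leading [CLS]
    # but no trailing [SEP] (see build_inputs_with_special_tokens above).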
2160
models/backbones/bert/xbert.py
Normal file
File diff suppressed because it is too large
268
models/backbones/blip2.py
Executable file
@ -0,0 +1,268 @@
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

# import .backbones.common.dist_utils as dist_utils
# from minigpt4.common.dist_utils import download_cached_file
# from minigpt4.common.utils import is_url
# from minigpt4.common.logger import MetricLogger
# NOTE: the commented-out imports above do not resolve in this repo; the
# equivalents below are assumed to live under models.common, matching
# models/backbones/base_model.py. A MetricLogger implementation is still
# required by compute_sim_matrix() and is not imported here.
from models.common import dist_utils
from models.common.dist_utils import download_cached_file
from models.common.utils import is_url

from models.backbones.base_model import BaseModel
from models.backbones.Qformer import BertConfig, BertLMHeadModel
from models.backbones.eva_vit import create_eva_vit_g
from transformers import BertTokenizer


class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    def get_optimizer_params(self, weight_decay, lr_scale=1):

        vit_num_layers = self.visual_encoder.get_num_layer()
        lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))

        parameter_group_names = {}
        parameter_group_vars = {}

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias"):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay
            if 'visual_encoder' in name:
                layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
            else:
                layer_id = None

            if group_name not in parameter_group_names:
                if layer_id is not None:
                    scale = lr_scales[layer_id]
                else:
                    scale = 1
                parameter_group_names[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale
                }
                parameter_group_vars[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale
                }
            parameter_group_vars[group_name]["params"].append(param)
            parameter_group_names[group_name]["params"].append(name)
        # import json
        # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
        optim_params = list(parameter_group_vars.values())
        return optim_params


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def compute_sim_matrix(model, data_loader, **kwargs):
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        text_embed = F.normalize(model.text_proj(text_feat))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = model.vision_proj(image_feat)
        image_embed = F.normalize(image_embed, dim=-1)

        vit_feats.append(vit_feat.cpu())
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = []
    for image_embed in image_embeds:
        sim_q2t = image_embed @ text_embeds.t()
        sim_i2t, _ = sim_q2t.max(0)
        sims_matrix.append(sim_i2t)
    sims_matrix = torch.stack(sims_matrix, dim=0)

    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k_test, 1),
            text_atts=text_atts[start + i].repeat(k_test, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(
            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
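A minimal sketch of the Q-Former initialisation above. The 32 query tokens and the 1408-dim vision width (EVA ViT-g feature size) are typical BLIP-2 values assumed here for illustration:

    from models.backbones.blip2 import Blip2Base

    qformer, query_tokens = Blip2Base.init_Qformer(num_query_token=32, vision_width=1408)
    print(query_tokens.shape)   # torch.Size([1, 32, 768]) with the bert-base config

The query tokens are a learnable parameter initialised with the BERT initializer range; the Q-Former itself is a BertLMHeadModel with cross-attention inserted every cross_attention_freq layers.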
110
models/backbones/blip2_outputs.py
Executable file
@ -0,0 +1,110 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)


@dataclass
class BlipSimilarity(ModelOutput):
    sim_i2t: torch.FloatTensor = None
    sim_t2i: torch.FloatTensor = None

    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None


@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Data class for intermediate outputs of BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)

    """

    # uni-modal features
    image_embeds: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None


@dataclass
class BlipOutput(ModelOutput):
    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: BlipIntermediateOutput = None

    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None


@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Data class of features from BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional

        The first embedding or feature is for the [CLS] token.

        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None
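For illustration, a hypothetical training step could report its losses through these dataclasses like so (all tensors below are dummies, not values from this commit):

    import torch
    from models.backbones.blip2_outputs import BlipOutput, BlipIntermediateOutput

    out = BlipOutput(
        loss=torch.tensor(2.1),
        loss_itc=torch.tensor(0.7),
        loss_itm=torch.tensor(0.6),
        loss_lm=torch.tensor(0.8),
        intermediate_output=BlipIntermediateOutput(image_embeds=torch.randn(4, 257, 768)),
    )
    print(out.loss, out.intermediate_output.image_embeds.shape)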
83
models/backbones/clip_vision_encoder.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionEncoder(nn.Module):
|
||||||
|
def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_loaded = False
|
||||||
|
|
||||||
|
self.vision_encoder_name = encoder_name
|
||||||
|
# self.select_layer = args.mm_vision_select_layer
|
||||||
|
# self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
||||||
|
self.select_layer = -1
|
||||||
|
self.select_feature = "patch"
|
||||||
|
if not delay_load:
|
||||||
|
self.load_model()
|
||||||
|
else:
|
||||||
|
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
|
||||||
|
self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
|
||||||
|
self.vision_encoder.requires_grad_(False)
|
||||||
|
|
||||||
|
self.is_loaded = True
|
||||||
|
|
||||||
|
def feature_select(self, image_forward_outs):
|
||||||
|
image_features = image_forward_outs.hidden_states[self.select_layer]
|
||||||
|
if self.select_feature == 'patch':
|
||||||
|
image_features = image_features[:, :]
|
||||||
|
elif self.select_feature == 'cls_patch':
|
||||||
|
image_features = image_features
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, images):
|
||||||
|
if type(images) is list:
|
||||||
|
image_features = []
|
||||||
|
for image in images:
|
||||||
|
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
||||||
|
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
||||||
|
image_features.append(image_feature)
|
||||||
|
else:
|
||||||
|
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
||||||
|
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
||||||
|
# print("image feature shape", image_features.shape)
|
||||||
|
# print(type(image_forward_outs))
|
||||||
|
# print(type(image_forward_outs.shape))
|
||||||
|
# image_features = image_forward_outs.to(images.dtype)
|
||||||
|
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_feature(self):
|
||||||
|
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.vision_encoder.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.vision_encoder.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config(self):
|
||||||
|
if self.is_loaded:
|
||||||
|
return self.vision_encoder.config
|
||||||
|
else:
|
||||||
|
return self.cfg_only
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_size(self):
|
||||||
|
return self.config.hidden_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_patches(self):
|
||||||
|
return (self.config.image_size // self.config.patch_size) ** 2
|
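
# --- Editor's note (added, not in the original commit): a minimal usage sketch of
# CLIPVisionEncoder. The random tensor stands in for a real preprocessed batch; the
# ViT-L/14 checkpoint is downloaded on first use. Shapes follow the code above:
# 224px / 14px patches -> 256 patch tokens plus the [CLS] token, hidden size 1024.
if __name__ == "__main__":
    encoder = CLIPVisionEncoder()                   # loads the frozen CLIP ViT-L/14 weights
    dummy_pixels = torch.randn(2, 3, 224, 224)      # stand-in for image_processor output
    feats = encoder(dummy_pixels)                   # (2, num_patches + 1, hidden_size)
    print(feats.shape, encoder.num_patches)         # torch.Size([2, 257, 1024]) 256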
141
models/backbones/encoder_decoder/builder.py
Normal file
@@ -0,0 +1,141 @@
import glog as logger
import re
import json

from peft import LoraConfig, get_peft_model

from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart import BartConfig, BartForConditionalGeneration, BartEncoder, BartForCausalLM


def build_encoder_decoder(model_config):
    """Build the (encoder-)decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Encoder Decoder [Type = {}]'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        if model_config['use_decoder_only']:
            model_cls = BartForCausalLM
        else:
            model_cls = BartForConditionalGeneration
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
    enc_dec_config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = enc_dec_config.d_model
    # enc_dec_config.encoder_layers = enc_dec_config.encoder_layers - model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    enc_dec = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=enc_dec_config
    )

    # first_k = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    # enc_dec.model.encoder.remove_first_k_layers(first_k)
    # get the last encoder layers
    # enc_dec.

    if model_config['use_lora_enc_dec']:
        # load the LoRA config
        with open(model_config['lora_config'], 'r') as f:
            lora_config = json.load(f)

        # find the Linear layers to apply LoRA to
        model_modules = str(enc_dec.modules)
        pattern = r'\((\w+)\): Linear'
        linear_layer_names = re.findall(pattern, model_modules)

        names = []
        # collect the names of the Linear layers
        for name in linear_layer_names:
            names.append(name)
        target_modules = list(set(names))

        lora_config['target_modules'] = target_modules

        lora_config = LoraConfig(**lora_config)

        enc_dec = get_peft_model(enc_dec, lora_config)

    return enc_dec


def build_encoder(model_config, expert_type, modality=None):
    """Build an encoder expert (modality- or grounding-specific).

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    log_txt = '[INFO] Loading {} Expert'.format(expert_type)
    if modality is not None:
        log_txt += ' [Modality = {}]'.format(modality)
    log_txt += ' [Type = {}]'.format(model_config['enc_dec_name'])

    logger.info(log_txt)

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartEncoder
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    config.modality_expert_layers = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    config.grounding_expert_layers = model_config['num_layers_grounding_expert_{}'.format(model_config['enc_dec_family'])]

    model_config['enc_dec_dim'] = config.d_model

    expert = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config,
        expert_type=expert_type,
        modality=modality
    )

    if model_config['use_lora_expert']:
        # load the LoRA config
        with open(model_config['lora_config'], 'r') as f:
            lora_config = json.load(f)

        # find the Linear layers to apply LoRA to
        model_modules = str(expert.modules)
        pattern = r'\((\w+)\): Linear'
        linear_layer_names = re.findall(pattern, model_modules)

        names = []
        # collect the names of the Linear layers
        for name in linear_layer_names:
            names.append(name)
        target_modules = list(set(names))

        lora_config['target_modules'] = target_modules

        lora_config = LoraConfig(**lora_config)

        expert = get_peft_model(expert, lora_config)

    # expert = model_cls(
    #     config=config,
    #     expert_type=expert_type,
    #     modality=modality
    # )

    return expert
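
# --- Editor's note (added, not in the original commit): an illustrative model_config for
# the builders above. The key names are the ones read by build_encoder_decoder and
# build_encoder; the concrete values (checkpoint name, layer counts, lora path) are
# assumptions, not values shipped with the repository.
if __name__ == "__main__":
    example_model_config = {
        'enc_dec_name': 'facebook/bart-large',     # HF checkpoint to initialise from (assumed)
        'enc_dec_family': 'bart',                  # 'bart' or 'flan_t5'
        'use_decoder_only': False,                 # True -> BartForCausalLM
        'use_lora_enc_dec': False,                 # True -> wrap the enc-dec with PEFT LoRA
        'use_lora_expert': False,                  # True -> wrap each expert with PEFT LoRA
        'lora_config': 'config/lora.json',         # hypothetical path, only read when LoRA is on
        'num_layers_modality_expert_bart': 2,      # illustrative expert depths
        'num_layers_grounding_expert_bart': 2,
    }
    enc_dec = build_encoder_decoder(example_model_config)
    # the builder fills in the model width as a side effect
    print(type(enc_dec).__name__, example_model_config['enc_dec_dim'])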
65
models/backbones/encoder_decoder/builder_orig.py
Normal file
@@ -0,0 +1,65 @@
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart_original import BartConfig, BartForConditionalGeneration, BartEncoder

import glog as logger


def build_encoder_decoder(model_config):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Encoder Decoder: {}'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartForConditionalGeneration
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = config.d_model
    enc_dec = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config
    )

    return enc_dec


def build_encoder(model_config):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Expert as Encoder of {}'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartEncoder
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = config.d_model
    config.encoder_layers = model_config['num_layers_modality_expert']

    expert = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config
    )

    return expert
19
models/backbones/encoder_decoder/outputs.py
Normal file
@@ -0,0 +1,19 @@
from typing import Optional, Tuple
import torch
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass


@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None
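
# --- Editor's note (added, not in the original commit): Seq2SeqV2DialOutput closely mirrors
# Hugging Face's Seq2SeqLMOutput and adds an `encoder_outputs` field so the full encoder pass
# can be carried along. A minimal, self-contained construction with dummy tensors:
if __name__ == "__main__":
    dummy = Seq2SeqV2DialOutput(
        loss=torch.tensor(0.0),
        logits=torch.zeros(1, 5, 50265),               # (batch, target_len, vocab); illustrative sizes
        encoder_last_hidden_state=torch.zeros(1, 10, 1024),
    )
    print(dummy.loss, dummy.logits.shape)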
2044
models/backbones/encoder_decoder/xbart.py
Normal file
File diff suppressed because it is too large
1954
models/backbones/encoder_decoder/xbart_original.py
Normal file
File diff suppressed because it is too large
2075
models/backbones/encoder_decoder/xflan_t5.py
Normal file
File diff suppressed because it is too large
455
models/backbones/eva_vit.py
Executable file
@@ -0,0 +1,455 @@
|
||||||
|
# Based on EVA, BEIT, timm and DeiT code bases
|
||||||
|
# https://github.com/baaivision/EVA
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||||
|
# https://github.com/microsoft/unilm/tree/master/beit
|
||||||
|
# https://github.com/facebookresearch/deit/
|
||||||
|
# https://github.com/facebookresearch/dino
|
||||||
|
# --------------------------------------------------------'
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
|
||||||
|
from models.common.dist_utils import download_cached_file
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||||
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
"""
|
||||||
|
def __init__(self, drop_prob=None):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return 'p={}'.format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
# x = self.drop(x)
|
||||||
|
# commented out to follow the original BERT implementation
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
||||||
|
proj_drop=0., window_size=None, attn_head_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
if attn_head_dim is not None:
|
||||||
|
head_dim = attn_head_dim
|
||||||
|
all_head_dim = head_dim * self.num_heads
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||||
|
if qkv_bias:
|
||||||
|
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||||
|
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||||
|
else:
|
||||||
|
self.q_bias = None
|
||||||
|
self.v_bias = None
|
||||||
|
|
||||||
|
if window_size:
|
||||||
|
self.window_size = window_size
|
||||||
|
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||||
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
|
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||||
|
# cls to token & token 2 cls & cls to cls
|
||||||
|
|
||||||
|
# get pair-wise relative position index for each token inside the window
|
||||||
|
coords_h = torch.arange(window_size[0])
|
||||||
|
coords_w = torch.arange(window_size[1])
|
||||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||||
|
relative_coords[:, :, 1] += window_size[1] - 1
|
||||||
|
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||||
|
relative_position_index = \
|
||||||
|
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
||||||
|
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||||
|
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||||
|
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||||
|
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||||
|
|
||||||
|
self.register_buffer("relative_position_index", relative_position_index)
|
||||||
|
else:
|
||||||
|
self.window_size = None
|
||||||
|
self.relative_position_bias_table = None
|
||||||
|
self.relative_position_index = None
|
||||||
|
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(all_head_dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, rel_pos_bias=None):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv_bias = None
|
||||||
|
if self.q_bias is not None:
|
||||||
|
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||||
|
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||||
|
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
|
||||||
|
if self.relative_position_bias_table is not None:
|
||||||
|
relative_position_bias = \
|
||||||
|
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
|
self.window_size[0] * self.window_size[1] + 1,
|
||||||
|
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||||
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
if rel_pos_bias is not None:
|
||||||
|
attn = attn + rel_pos_bias
|
||||||
|
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||||
|
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||||
|
window_size=None, attn_head_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
if init_values is not None and init_values > 0:
|
||||||
|
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||||
|
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||||
|
else:
|
||||||
|
self.gamma_1, self.gamma_2 = None, None
|
||||||
|
|
||||||
|
def forward(self, x, rel_pos_bias=None):
|
||||||
|
if self.gamma_1 is None:
|
||||||
|
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
else:
|
||||||
|
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||||
|
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
""" Image to Patch Embedding
|
||||||
|
"""
|
||||||
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||||
|
super().__init__()
|
||||||
|
img_size = to_2tuple(img_size)
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||||
|
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||||
|
self.img_size = img_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_patches = num_patches
|
||||||
|
|
||||||
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
# FIXME look at relaxing size constraints
|
||||||
|
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||||
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||||
|
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionBias(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, window_size, num_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.window_size = window_size
|
||||||
|
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||||
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
|
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||||
|
# cls to token & token 2 cls & cls to cls
|
||||||
|
|
||||||
|
# get pair-wise relative position index for each token inside the window
|
||||||
|
coords_h = torch.arange(window_size[0])
|
||||||
|
coords_w = torch.arange(window_size[1])
|
||||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||||
|
relative_coords[:, :, 1] += window_size[1] - 1
|
||||||
|
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||||
|
relative_position_index = \
|
||||||
|
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
||||||
|
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||||
|
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||||
|
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||||
|
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||||
|
|
||||||
|
self.register_buffer("relative_position_index", relative_position_index)
|
||||||
|
|
||||||
|
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
relative_position_bias = \
|
||||||
|
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
|
self.window_size[0] * self.window_size[1] + 1,
|
||||||
|
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||||
|
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||||
|
"""
|
||||||
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||||
|
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||||
|
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
||||||
|
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
||||||
|
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
||||||
|
super().__init__()
|
||||||
|
self.image_size = img_size
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
|
if use_abs_pos_emb:
|
||||||
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||||
|
else:
|
||||||
|
self.pos_embed = None
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
if use_shared_rel_pos_bias:
|
||||||
|
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
||||||
|
else:
|
||||||
|
self.rel_pos_bias = None
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||||
|
self.use_rel_pos_bias = use_rel_pos_bias
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
Block(
|
||||||
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||||
|
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
||||||
|
for i in range(depth)])
|
||||||
|
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||||
|
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||||
|
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
if self.pos_embed is not None:
|
||||||
|
trunc_normal_(self.pos_embed, std=.02)
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
|
# trunc_normal_(self.mask_token, std=.02)
|
||||||
|
# if isinstance(self.head, nn.Linear):
|
||||||
|
# trunc_normal_(self.head.weight, std=.02)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
self.fix_init_weight()
|
||||||
|
# if isinstance(self.head, nn.Linear):
|
||||||
|
# self.head.weight.data.mul_(init_scale)
|
||||||
|
# self.head.bias.data.mul_(init_scale)
|
||||||
|
|
||||||
|
def fix_init_weight(self):
|
||||||
|
def rescale(param, layer_id):
|
||||||
|
param.div_(math.sqrt(2.0 * layer_id))
|
||||||
|
|
||||||
|
for layer_id, layer in enumerate(self.blocks):
|
||||||
|
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||||
|
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.head
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool=''):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
batch_size, seq_len, _ = x.size()
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
if self.pos_embed is not None:
|
||||||
|
x = x + self.pos_embed
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||||
|
for blk in self.blocks:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
||||||
|
else:
|
||||||
|
x = blk(x, rel_pos_bias)
|
||||||
|
return x
|
||||||
|
# x = self.norm(x)
|
||||||
|
|
||||||
|
# if self.fc_norm is not None:
|
||||||
|
# t = x[:, 1:, :]
|
||||||
|
# return self.fc_norm(t.mean(1))
|
||||||
|
# else:
|
||||||
|
# return x[:, 0]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
# x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_intermediate_layers(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
batch_size, seq_len, _ = x.size()
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
if self.pos_embed is not None:
|
||||||
|
x = x + self.pos_embed
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
features = []
|
||||||
|
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x, rel_pos_bias)
|
||||||
|
features.append(x)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_num_layer(self, var_name=""):
|
||||||
|
if var_name in ("cls_token", "mask_token", "pos_embed"):
|
||||||
|
return 0
|
||||||
|
elif var_name.startswith("patch_embed"):
|
||||||
|
return 0
|
||||||
|
elif var_name.startswith("rel_pos_bias"):
|
||||||
|
return len(self.blocks) - 1
|
||||||
|
elif var_name.startswith("blocks"):
|
||||||
|
layer_id = int(var_name.split('.')[1])
|
||||||
|
return layer_id + 1
|
||||||
|
else:
|
||||||
|
return len(self.blocks)
|
||||||
|
|
||||||
|
def interpolate_pos_embed(model, checkpoint_model):
|
||||||
|
if 'pos_embed' in checkpoint_model:
|
||||||
|
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
||||||
|
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||||
|
num_patches = model.patch_embed.num_patches
|
||||||
|
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||||
|
# height (== width) for the checkpoint position embedding
|
||||||
|
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||||
|
# height (== width) for the new position embedding
|
||||||
|
new_size = int(num_patches ** 0.5)
|
||||||
|
# class_token and dist_token are kept unchanged
|
||||||
|
if orig_size != new_size:
|
||||||
|
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
||||||
|
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||||
|
# only the position tokens are interpolated
|
||||||
|
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||||
|
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||||
|
pos_tokens = torch.nn.functional.interpolate(
|
||||||
|
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||||
|
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||||
|
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||||
|
checkpoint_model['pos_embed'] = new_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def convert_weights_to_fp16(model: nn.Module):
|
||||||
|
"""Convert applicable model parameters to fp16"""
|
||||||
|
|
||||||
|
def _convert_weights_to_fp16(l):
|
||||||
|
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||||
|
l.weight.data = l.weight.data.half()
|
||||||
|
if l.bias is not None:
|
||||||
|
l.bias.data = l.bias.data.half()
|
||||||
|
|
||||||
|
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||||
|
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||||
|
# tensor = getattr(l, attr)
|
||||||
|
# if tensor is not None:
|
||||||
|
# tensor.data = tensor.data.half()
|
||||||
|
|
||||||
|
model.apply(_convert_weights_to_fp16)
|
||||||
|
|
||||||
|
|
||||||
|
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
|
||||||
|
model = VisionTransformer(
|
||||||
|
img_size=img_size,
|
||||||
|
patch_size=14,
|
||||||
|
use_mean_pooling=False,
|
||||||
|
embed_dim=1408,
|
||||||
|
depth=39,
|
||||||
|
# depth = 37,
|
||||||
|
num_heads=1408//88,
|
||||||
|
mlp_ratio=4.3637,
|
||||||
|
qkv_bias=True,
|
||||||
|
drop_path_rate=drop_path_rate,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
)
|
||||||
|
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
||||||
|
cached_file = download_cached_file(
|
||||||
|
url, check_hash=False, progress=True
|
||||||
|
)
|
||||||
|
state_dict = torch.load(cached_file, map_location="cpu")
|
||||||
|
interpolate_pos_embed(model,state_dict)
|
||||||
|
|
||||||
|
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||||
|
# print(incompatible_keys)
|
||||||
|
|
||||||
|
if precision == "fp16":
|
||||||
|
# model.to("cuda")
|
||||||
|
convert_weights_to_fp16(model)
|
||||||
|
return model
|
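
# --- Editor's note (added, not in the original commit): a minimal sketch of building the
# EVA ViT-g backbone defined above. fp32 precision is used so the sketch also runs on CPU;
# the weights are downloaded from the URL hard-coded in create_eva_vit_g on first run.
if __name__ == "__main__":
    vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32").eval()
    with torch.no_grad():
        tokens = vit(torch.randn(1, 3, 224, 224))
    print(tokens.shape)   # (1, 257, 1408): one [CLS] token plus 16x16 patch tokens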
895
models/backbones/mini_gpt4_llama_v2.py
Executable file
@@ -0,0 +1,895 @@
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast as autocast
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from minigpt4.common.registry import registry
|
||||||
|
from minigpt4.models.blip2 import Blip2Base, disabled_train
|
||||||
|
# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
|
||||||
|
# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
|
||||||
|
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
|
||||||
|
from peft import (
|
||||||
|
LoraConfig,
|
||||||
|
get_peft_model,
|
||||||
|
get_peft_model_state_dict,
|
||||||
|
prepare_model_for_int8_training,
|
||||||
|
set_peft_model_state_dict,
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from minigpt4.models import policies
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register_model("mini_gpt4_llama_v2")
|
||||||
|
class MiniGPT4_llama_v2(Blip2Base):
|
||||||
|
"""
|
||||||
|
BLIP2 GPT-LLAMA model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||||
|
"pretrain_vicuna": "configs/models/minigpt4.yaml",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vit_model="eva_clip_g",
|
||||||
|
img_size=224,
|
||||||
|
drop_path_rate=0,
|
||||||
|
use_grad_checkpoint=False,
|
||||||
|
vit_precision="fp16",
|
||||||
|
freeze_vit=True,
|
||||||
|
llama_model="",
|
||||||
|
prompt_path="",
|
||||||
|
prompt_template="",
|
||||||
|
max_txt_len=32,
|
||||||
|
low_resource=False, # use 8 bit and put vit in cpu
|
||||||
|
end_sym='\n',
|
||||||
|
lora_r = 8,
|
||||||
|
lora_target_modules = ["q_proj","v_proj"],
|
||||||
|
lora_alpha=16,
|
||||||
|
# lora_r = 16,
|
||||||
|
# lora_target_modules = ["q_proj","v_proj","v_proj"],
|
||||||
|
lora_dropout= 0.05,
|
||||||
|
ckpt_path = "",
|
||||||
|
system_prompt= False,
|
||||||
|
chat_template=False,
|
||||||
|
token_pooling=True,
|
||||||
|
use_grad_checkpoint_llm=False,
|
||||||
|
max_context_len=3800,
|
||||||
|
remove_template = False,
|
||||||
|
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if "Mistral" in llama_model:
|
||||||
|
from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
|
||||||
|
print("Mistral model")
|
||||||
|
self.model_type = "Mistral"
|
||||||
|
else:
|
||||||
|
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
|
||||||
|
print("Llama model")
|
||||||
|
self.model_type = "Llama"
|
||||||
|
self.tokenizer = self.init_tokenizer()
|
||||||
|
self.low_resource = low_resource
|
||||||
|
self.token_pooling = token_pooling
|
||||||
|
self.remove_template = remove_template
|
||||||
|
|
||||||
|
print("token pooling", self.token_pooling)
|
||||||
|
|
||||||
|
|
||||||
|
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
|
||||||
|
self.max_context_len = max_context_len
|
||||||
|
self.chat_template = chat_template
|
||||||
|
|
||||||
|
# print('Loading VIT')
|
||||||
|
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||||
|
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||||
|
# )
|
||||||
|
|
||||||
|
if freeze_vit:
|
||||||
|
# vit_precision="fp32"
|
||||||
|
print("vit precision", vit_precision)
|
||||||
|
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||||
|
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||||
|
)
|
||||||
|
for name, param in self.visual_encoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.visual_encoder = self.visual_encoder.eval()
|
||||||
|
self.visual_encoder.train = disabled_train
|
||||||
|
for name, param in self.ln_vision.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.ln_vision = self.ln_vision.eval()
|
||||||
|
self.ln_vision.train = disabled_train
|
||||||
|
logging.info("freeze vision encoder")
|
||||||
|
print("freeze the vision encoder")
|
||||||
|
|
||||||
|
else:
|
||||||
|
vit_precision="fp32"
|
||||||
|
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||||
|
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||||
|
)
|
||||||
|
|
||||||
|
print("unfreeze the vision encoder")
|
||||||
|
|
||||||
|
print('Loading VIT Done')
|
||||||
|
|
||||||
|
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
|
||||||
|
# assert False
|
||||||
|
|
||||||
|
print('Loading LLAMA')
|
||||||
|
|
||||||
|
|
||||||
|
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||||
|
|
||||||
|
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model,use_fast=False) #
|
||||||
|
self.llama_tokenizer.pad_token = "$$"
|
||||||
|
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print("self.low_resource",self.low_resource)
|
||||||
|
if self.low_resource:
|
||||||
|
self.llama_model = llm_model.from_pretrained(
|
||||||
|
llama_model,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
# torch_dtype = torch.bfloat16,
|
||||||
|
load_in_8bit=True,
|
||||||
|
# device_map = "balanced"
|
||||||
|
# device_map="auto",
|
||||||
|
device_map={'':torch.cuda.current_device()},
|
||||||
|
# device_map={'':0}
|
||||||
|
|
||||||
|
)
|
||||||
|
# bnb_config = BitsAndBytesConfig(
|
||||||
|
# load_in_4bit=True,
|
||||||
|
# bnb_4bit_use_double_quant=True,
|
||||||
|
# bnb_4bit_quant_type="nf4",
|
||||||
|
# bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
# )
|
||||||
|
# self.llama_model = llm_model.from_pretrained(
|
||||||
|
# llama_model,
|
||||||
|
# torch_dtype=torch.bfloat16,
|
||||||
|
# device_map={'':torch.cuda.current_device()},
|
||||||
|
# quantization_config=bnb_config,
|
||||||
|
# )
|
||||||
|
else:
|
||||||
|
self.llama_model = llm_model.from_pretrained(
|
||||||
|
llama_model,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
|
||||||
|
self.llama_model = prepare_model_for_int8_training(self.llama_model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
loraconfig = LoraConfig(
|
||||||
|
r=lora_r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
target_modules=lora_target_modules,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM"
|
||||||
|
)
|
||||||
|
self.llama_model = get_peft_model(self.llama_model, loraconfig)
|
||||||
|
|
||||||
|
# if ckpt_path:
|
||||||
|
# print('load the llm under lora')
|
||||||
|
# ckpt = torch.load(ckpt_path)
|
||||||
|
# set_peft_model_state_dict(self.llama_model,ckpt)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.llama_model.print_trainable_parameters()
|
||||||
|
|
||||||
|
if self.use_grad_checkpoint_llm:
|
||||||
|
self.llama_model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
# if not self.low_resource:
|
||||||
|
# for name, param in self.llama_model.named_parameters():
|
||||||
|
# if "embed_token" in name:
|
||||||
|
# param.data = param.data.float()
|
||||||
|
# param.requires_grad = True
|
||||||
|
|
||||||
|
|
||||||
|
print('Loading LLAMA Done')
|
||||||
|
|
||||||
|
|
||||||
|
if self.token_pooling:
|
||||||
|
self.llama_proj = nn.Linear(
|
||||||
|
1408*4, self.llama_model.config.hidden_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.llama_proj = nn.Linear(
|
||||||
|
1408, self.llama_model.config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.max_txt_len = max_txt_len
|
||||||
|
self.end_sym = end_sym
|
||||||
|
|
||||||
|
if prompt_path:
|
||||||
|
with open(prompt_path, 'r') as f:
|
||||||
|
raw_prompts = f.read().splitlines()
|
||||||
|
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
|
||||||
|
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
|
||||||
|
print('Load {} training prompts'.format(len(self.prompt_list)))
|
||||||
|
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
|
||||||
|
else:
|
||||||
|
self.prompt_list = []
|
||||||
|
|
||||||
|
def encode_img(self, image):
|
||||||
|
device = image.device
|
||||||
|
if len(image.shape) > 4:
|
||||||
|
image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
|
||||||
|
with self.maybe_autocast():
|
||||||
|
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
|
||||||
|
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
|
||||||
|
bs, pn, hs = image_embeds.shape
|
||||||
|
if self.token_pooling: # concat the each 4 tokens into one token (200,64,5632)
|
||||||
|
image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
|
||||||
|
|
||||||
|
inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
|
||||||
|
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||||
|
return inputs_llama, atts_llama
|
||||||
|
|
||||||
|
def get_context_emb(self, prompt, img_list):
|
||||||
|
img_device = img_list[0].device
|
||||||
|
prompt_segs = prompt.split('<ImageHere>')
|
||||||
|
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||||
|
seg_tokens = [
|
||||||
|
self.llama_tokenizer(
|
||||||
|
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
|
||||||
|
for i, seg in enumerate(prompt_segs)
|
||||||
|
]
|
||||||
|
|
||||||
|
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||||
|
|
||||||
|
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||||
|
|
||||||
|
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||||
|
# # truncate the length of tokens to the max context window
|
||||||
|
# mixed_embs_without_instruction = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair]
|
||||||
|
# mixed_embs_without_instruction=torch.cat(mixed_embs_without_instruction, dim=1)
|
||||||
|
# # check if the number of token in the second dimention is more than the max context window then truncate it
|
||||||
|
# context_window=self.max_context_len-seg_embs[-1].shape[1]
|
||||||
|
# if mixed_embs_without_instruction.shape[1] > context_window :
|
||||||
|
# mixed_embs_without_instruction = mixed_embs_without_instruction[:, 0:context_window]
|
||||||
|
# mixed_embs=torch.cat([mixed_embs_without_instruction,seg_embs[-1]], dim=1)
|
||||||
|
# print("mixed_embs",mixed_embs.shape)
|
||||||
|
|
||||||
|
return mixed_embs
|
||||||
|
|
||||||
|
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
|
||||||
|
if prompts is None or len(prompts) == 0:
|
||||||
|
# prompts is not provided, just return the original image embedding
|
||||||
|
return img_embeds, atts_img
|
||||||
|
elif img_embeds is None:
|
||||||
|
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
|
||||||
|
self.llama_tokenizer.padding_side = "right"
|
||||||
|
prompt_tokens = self.llama_tokenizer(
|
||||||
|
prompts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
add_special_tokens=False
|
||||||
|
).to(self.device)
|
||||||
|
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
|
||||||
|
atts_prompt = prompt_tokens.attention_mask
|
||||||
|
return prompt_embeds, atts_prompt
|
||||||
|
|
||||||
|
else:
|
||||||
|
# return the multi-modal embedding in right padding
|
||||||
|
emb_lists = []
|
||||||
|
|
||||||
|
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
|
||||||
|
pn = each_img_embed.shape[-2]
|
||||||
|
if lengths is not None:
|
||||||
|
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
|
||||||
|
each_img_embed = each_img_embed[:lengths[idx] * pn]
|
||||||
|
|
||||||
|
p_segs = each_prompt.split('<ImageHere>')
|
||||||
|
interleave_emb = []
|
||||||
|
for idx, seg in enumerate(p_segs[:-1]):
|
||||||
|
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||||
|
# print("p_embed device",p_tokens.input_ids.device)
|
||||||
|
# print("p_tokens",img_embeds.device)
|
||||||
|
# print("emb layer", list(self.llama_model.base_model.model.model.embed_tokens.parameters())[0].device)
|
||||||
|
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||||
|
|
||||||
|
# print("model device",self.llama_model.get_device())
|
||||||
|
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
|
||||||
|
|
||||||
|
wrapped_emb = torch.cat(interleave_emb, dim=1)
|
||||||
|
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||||
|
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||||
|
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
|
||||||
|
emb_lists.append(wrapped_emb)
|
||||||
|
|
||||||
|
emb_lens = [emb.shape[1] for emb in emb_lists]
|
||||||
|
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
|
||||||
|
|
||||||
|
# max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
|
||||||
|
max_length = self.max_context_len
|
||||||
|
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
|
||||||
|
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
|
||||||
|
|
||||||
|
for i, emb in enumerate(emb_lists):
|
||||||
|
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
|
||||||
|
wrapped_embs[i, :length] = emb[:, :length]
|
||||||
|
wrapped_atts[i, :length] = 1
|
||||||
|
|
||||||
|
return wrapped_embs, wrapped_atts
|
||||||
|
|
||||||
|
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
|
||||||
|
"""
|
||||||
|
Concatenate the batched input embedding and batched output embedding together.
|
||||||
|
Both the input and the output embedding should be right padded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_lens = []
|
||||||
|
cat_embs = []
|
||||||
|
cat_atts = []
|
||||||
|
|
||||||
|
for i in range(input_embs.size(0)):
|
||||||
|
input_len = input_atts[i].sum()
|
||||||
|
input_lens.append(input_len)
|
||||||
|
|
||||||
|
cat_embs.append(
|
||||||
|
torch.cat([
|
||||||
|
input_embs[i][:input_len],
|
||||||
|
output_embs[i],
|
||||||
|
input_embs[i][input_len:]
|
||||||
|
])
|
||||||
|
)
|
||||||
|
cat_atts.append(
|
||||||
|
torch.cat([
|
||||||
|
input_atts[i][:input_len],
|
||||||
|
output_atts[i],
|
||||||
|
input_atts[i][input_len:]
|
||||||
|
])
|
||||||
|
)
|
||||||
|
# print('===================================')
|
||||||
|
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
|
||||||
|
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
|
||||||
|
# print('check out emb: ', output_embs[i][:2])
|
||||||
|
# print('check out pad emb: ', output_embs[i][-2:])
|
||||||
|
# print('+++++++++++++++++++++++++++++++++++')
|
||||||
|
#
|
||||||
|
# print('check attn before: ', input_atts[i][:this_input_ones])
|
||||||
|
# print('check attn after: ', input_atts[i][this_input_ones:])
|
||||||
|
# print('check attn gt before: ', output_atts[i][:3])
|
||||||
|
# print('check attn gt after: ', output_atts[i][-3:])
|
||||||
|
|
||||||
|
cat_embs = torch.stack(cat_embs)
|
||||||
|
cat_atts = torch.stack(cat_atts)
|
||||||
|
return cat_embs, cat_atts, input_lens
|
||||||
|
|
||||||
|
def get_conv_emb(self, conv_q, conv_a, conv_img):
|
||||||
|
"""concatenate conversation and make sure the model is only trained to regress the answer"""
|
||||||
|
|
||||||
|
regress_embs_list = []
|
||||||
|
targets_list = []
|
||||||
|
|
||||||
|
batch_size = len(conv_q)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
|
||||||
|
assigned_imgs = conv_img[batch_idx]
|
||||||
|
questions = [self.prompt_wrap(
|
||||||
|
img_embeds=img,
|
||||||
|
atts_img=None,
|
||||||
|
prompts=[q],
|
||||||
|
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
|
||||||
|
q_embs = [emb for emb, _ in questions]
|
||||||
|
|
||||||
|
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
|
||||||
|
cur_emb = []
|
||||||
|
cur_target = []
|
||||||
|
for i in range(len(questions)):
|
||||||
|
cur_emb.append(q_embs[i])
|
||||||
|
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
|
||||||
|
|
||||||
|
cur_emb.append(self.embed_tokens(answers[i].input_ids))
|
||||||
|
cur_target.append(answers[i].input_ids)
|
||||||
|
|
||||||
|
cur_emb = torch.cat(cur_emb, dim=1)
|
||||||
|
cur_target = torch.cat(cur_target, dim=1)
|
||||||
|
|
||||||
|
regress_embs_list.append(cur_emb)
|
||||||
|
targets_list.append(cur_target)
|
||||||
|
|
||||||
|
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
|
||||||
|
|
||||||
|
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
|
||||||
|
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
|
||||||
|
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
cur_len = regress_embs_list[batch_idx].shape[1]
|
||||||
|
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
|
||||||
|
regress_attn[batch_idx, :cur_len] = 1
|
||||||
|
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
|
||||||
|
|
||||||
|
return regress_embeds, regress_attn, targets
|
||||||
|
|
||||||
|
def preparing_embedding(self, samples):
|
||||||
|
def remove_special_tokens(data):
|
||||||
|
|
||||||
|
# if "instruction_input" in data:
|
||||||
|
data = [instruct.replace(" [caption]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [vqa]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [grounding]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [identify]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [refer]","") for instruct in data]
|
||||||
|
return data
|
||||||
|
|
||||||
|
### prepare input tokens
|
||||||
|
if 'image' in samples:
|
||||||
|
img_embeds, img_atts = self.encode_img(samples["image"])
|
||||||
|
# print("img_embeds shape",img_embeds.shape)
|
||||||
|
else:
|
||||||
|
img_embeds = img_atts = None
|
||||||
|
|
||||||
|
if 'conv_q' in samples:
|
||||||
|
# handling conversation datasets
|
||||||
|
conv_q, conv_a = samples['conv_q'], samples['conv_a']
|
||||||
|
|
||||||
|
connect_sym = samples['connect_sym'][0]
|
||||||
|
conv_q = [q.split(connect_sym)for q in conv_q]
|
||||||
|
conv_a = [a.split(connect_sym) for a in conv_a]
|
||||||
|
conv_img = assign_imgs(conv_q, img_embeds)
|
||||||
|
|
||||||
|
if self.chat_template:
|
||||||
|
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
|
||||||
|
|
||||||
|
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
|
||||||
|
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
instruction = samples["instruction_input"] if "instruction_input" in samples else None
|
||||||
|
|
||||||
|
# print("instruction before", instruction)
|
||||||
|
if self.remove_template:
|
||||||
|
instruction = remove_special_tokens(instruction)
|
||||||
|
# print("instruction after", instruction)
|
||||||
|
|
||||||
|
if self.chat_template:
|
||||||
|
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
|
||||||
|
|
||||||
|
if 'length' in samples:
|
||||||
|
# the input is a sequence of images (e.g. video frames)
|
||||||
|
bsz, pn, hs = img_embeds.shape
|
||||||
|
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096)
|
||||||
|
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
|
||||||
|
else:
|
||||||
|
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
|
||||||
|
|
||||||
|
### prepare target tokens
|
||||||
|
self.llama_tokenizer.padding_side = "right"
|
||||||
|
text = [t + self.end_sym for t in samples["answer"]]
|
||||||
|
|
||||||
|
regress_tokens = self.llama_tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_txt_len,
|
||||||
|
add_special_tokens=False
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
regress_token_ids = regress_tokens.input_ids
|
||||||
|
regress_atts = regress_tokens.attention_mask
|
||||||
|
part_targets = regress_token_ids.masked_fill(
|
||||||
|
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
|
||||||
|
)
|
||||||
|
|
||||||
|
regress_embeds = self.embed_tokens(regress_token_ids)
|
||||||
|
|
||||||
|
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
|
||||||
|
|
||||||
|
def forward(self, samples, reduction="mean"):
|
||||||
|
# prepare the embedding to condition and the embedding to regress
|
||||||
|
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
|
||||||
|
self.preparing_embedding(samples)
|
||||||
|
|
||||||
|
# concat the embedding to condition and the embedding to regress
|
||||||
|
inputs_embeds, attention_mask, input_lens = \
|
||||||
|
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
|
||||||
|
print("inputs_embeds shape",inputs_embeds.shape)
|
||||||
|
print("cond_embeds shape",cond_embeds.shape)
|
||||||
|
print("regress_embeds shape",regress_embeds.shape)
|
||||||
|
# get bos token embedding
|
||||||
|
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
|
||||||
|
bos_embeds = self.embed_tokens(bos)
|
||||||
|
bos_atts = attention_mask[:, :1]
|
||||||
|
|
||||||
|
# add bos token at the beginning
|
||||||
|
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
|
||||||
|
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
|
||||||
|
# print length of instruction_input and answer words
|
||||||
|
# for i in range (len(samples["instruction_input"])):
|
||||||
|
# print("instruction_input length",len(samples["instruction_input"][i].split(" ")))
|
||||||
|
# print("answer length",len(samples["answer"][i].split(" ")))
|
||||||
|
# ensemble the final targets
|
||||||
|
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
|
||||||
|
dtype=torch.long).to(self.device).fill_(-100)
|
||||||
|
for i, target in enumerate(part_targets):
|
||||||
|
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
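# Hypothetical example: with a 10-token condition prefix and a 4-token answer,
# the answer labels land at target positions 11..14 once the prepended BOS
# embedding shifts everything right by one.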
|
||||||
|
print("targets shape",targets.shape)
|
||||||
|
with self.maybe_autocast():
|
||||||
|
outputs = self.llama_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
return_dict=True,
|
||||||
|
labels=targets,
|
||||||
|
reduction=reduction
|
||||||
|
)
|
||||||
|
loss = outputs.loss
|
||||||
|
|
||||||
|
return {"loss": loss}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
texts,
|
||||||
|
use_nucleus_sampling=False,
|
||||||
|
num_beams=1,
|
||||||
|
max_new_tokens=20,
|
||||||
|
min_length=1,
|
||||||
|
top_p=0.9,
|
||||||
|
repetition_penalty=1.5,
|
||||||
|
length_penalty=1,
|
||||||
|
temperature=1,
|
||||||
|
do_sample=False,
|
||||||
|
stop_words_ids=[2],
|
||||||
|
lengths=None,
|
||||||
|
return_video_temporal_features=False,
|
||||||
|
img_embeds=None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Generation helper used at inference / test time.
|
||||||
|
'''
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||||
|
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||||
|
if img_embeds is None:
|
||||||
|
img_embeds, atts_img = self.encode_img(images.to(self.device))
|
||||||
|
else:
|
||||||
|
# Use image features passed in directly, e.g. of shape (4, 45, 64, 5632)
|
||||||
|
img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
|
||||||
|
img_embeds= img_embeds.to(self.device)
|
||||||
|
img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
|
||||||
|
atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||||
|
|
||||||
|
print("img_embeds shape",img_embeds.shape)
|
||||||
|
if lengths is not None:
|
||||||
|
image_lists = []
|
||||||
|
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
|
||||||
|
for idx, img_embed in enumerate(img_embeds):
|
||||||
|
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
|
||||||
|
else:
|
||||||
|
image_lists = [[image_emb[None]] for image_emb in img_embeds]
|
||||||
|
assert len(texts) == len(image_lists)
|
||||||
|
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
|
||||||
|
|
||||||
|
batch_size = len(batch_embs)
|
||||||
|
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||||
|
emb_dim = batch_embs[0].shape[2]
|
||||||
|
dtype = batch_embs[0].dtype
|
||||||
|
device = batch_embs[0].device
|
||||||
|
|
||||||
|
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||||
|
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||||
|
for i, emb in enumerate(batch_embs):
|
||||||
|
emb_len = emb.shape[1]
|
||||||
|
embs[i, -emb_len:] = emb[0]
|
||||||
|
attn_mask[i, -emb_len:] = 1
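# Left padding: shorter prompts are pushed to the right so generation starts
# immediately after the last real token of every sample in the batch.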
|
||||||
|
# print("inputs_embeds shape",embs.shape)
|
||||||
|
# print("attention_mask shape",attn_mask.shape)
|
||||||
|
# if the input embeddings exceed the model's context window (e.g. 4096 tokens),
# truncate them to the maximum context length, keeping the most recent tokens
|
||||||
|
if self.model_type == "Llama":
|
||||||
|
context_window = 3700
|
||||||
|
else:
|
||||||
|
context_window = 7500
|
||||||
|
if embs.shape[1] > context_window:
|
||||||
|
embs = embs[:, -context_window:]
|
||||||
|
attn_mask = attn_mask[:, -context_window:]
|
||||||
|
print("inputs_embeds shape",embs.shape)
|
||||||
|
print("attention_mask shape",attn_mask.shape)
|
||||||
|
with self.maybe_autocast():
|
||||||
|
if return_video_temporal_features:
|
||||||
|
last_hidden_state = self.llama_model(
|
||||||
|
inputs_embeds=embs,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).hidden_states[-1]
|
||||||
|
video_temporal_features = last_hidden_state.mean(dim=1)
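# Mean-pool the final hidden states over the sequence dimension to get one
# temporal feature vector per video for the optional return value.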
|
||||||
|
# normalize the temporal features using L2 norm
|
||||||
|
# video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True)
|
||||||
|
outputs = self.llama_model.generate(
|
||||||
|
inputs_embeds=embs,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
num_beams=num_beams,
|
||||||
|
do_sample=do_sample,
|
||||||
|
temperature=temperature,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
# stopping_criteria=stopping_criteria,
|
||||||
|
)
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
for output_token in outputs:
|
||||||
|
if output_token[0] == 0:
|
||||||
|
output_token = output_token[1:]
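# Drop a leading token id 0 (typically the <unk>/pad id in LLaMA tokenizers)
# before decoding.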
|
||||||
|
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||||
|
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||||
|
output_texts = output_texts.replace("<s>", "")
|
||||||
|
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||||
|
answers.append(output_texts)
|
||||||
|
if return_video_temporal_features:
|
||||||
|
return answers, video_temporal_features
|
||||||
|
else:
|
||||||
|
return answers
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate_text_only(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
seg_tokens,
|
||||||
|
use_nucleus_sampling=False,
|
||||||
|
num_beams=1,
|
||||||
|
max_new_tokens=20,
|
||||||
|
min_length=1,
|
||||||
|
top_p=0.9,
|
||||||
|
repetition_penalty=1.5,
|
||||||
|
length_penalty=1,
|
||||||
|
temperature=1,
|
||||||
|
do_sample=False,
|
||||||
|
stop_words_ids=[2],
|
||||||
|
lengths=None,
|
||||||
|
return_video_temporal_features=False,
|
||||||
|
img_embeds=None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Text-only generation helper used at inference / test time.
|
||||||
|
'''
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||||
|
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||||
|
|
||||||
|
# seg_tokens=[]
|
||||||
|
# for i, text in enumerate(texts):
|
||||||
|
# seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids)
|
||||||
|
|
||||||
|
batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]
|
||||||
|
|
||||||
|
# seg_embs = torch.cat(seg_embs, dim=1)
|
||||||
|
# print("seg_embs shape",seg_embs.shape)
|
||||||
|
# batch_embs=[seg_embs]
|
||||||
|
batch_size = len(batch_embs)
|
||||||
|
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||||
|
emb_dim = batch_embs[0].shape[2]
|
||||||
|
dtype = batch_embs[0].dtype
|
||||||
|
device = batch_embs[0].device
|
||||||
|
|
||||||
|
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||||
|
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||||
|
for i, emb in enumerate(batch_embs):
|
||||||
|
emb_len = emb.shape[1]
|
||||||
|
embs[i, -emb_len:] = emb[0]
|
||||||
|
attn_mask[i, -emb_len:] = 1
|
||||||
|
|
||||||
|
|
||||||
|
print("inputs_embeds shape",embs.shape)
|
||||||
|
print("attention_mask shape",attn_mask.shape)
|
||||||
|
with self.maybe_autocast():
|
||||||
|
outputs = self.llama_model.generate(
|
||||||
|
inputs_embeds=embs,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
num_beams=num_beams,
|
||||||
|
do_sample=do_sample,
|
||||||
|
temperature=temperature,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
# stopping_criteria=stopping_criteria,
|
||||||
|
)
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
for output_token in outputs:
|
||||||
|
if output_token[0] == 0:
|
||||||
|
output_token = output_token[1:]
|
||||||
|
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||||
|
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||||
|
output_texts = output_texts.replace("<s>", "")
|
||||||
|
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||||
|
answers.append(output_texts)
|
||||||
|
return answers
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def multi_select(self, images, texts, answers, num_cand=None):
|
||||||
|
all_losses = []
|
||||||
|
for answer in answers:
|
||||||
|
choice_samples = {
|
||||||
|
'image': images,
|
||||||
|
'instruction_input': texts,
|
||||||
|
'answer': answer
|
||||||
|
}
|
||||||
|
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
|
||||||
|
all_losses.append(loss)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
all_losses = torch.cat(all_losses, dim=-1)
|
||||||
|
if num_cand is not None:
|
||||||
|
for i in range(all_losses.shape[0]):
|
||||||
|
all_losses[i, num_cand[i]:] = 9999
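# Fill unused candidate slots with a large loss so they always rank last.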
|
||||||
|
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
||||||
|
return output_class_ranks.tolist()
|
||||||
|
|
||||||
|
def predict_answers(
|
||||||
|
self,
|
||||||
|
samples,
|
||||||
|
num_beams=5,
|
||||||
|
inference_method="generate",
|
||||||
|
max_len=10,
|
||||||
|
min_len=1,
|
||||||
|
num_ans_candidates=128,
|
||||||
|
answer_list=None,
|
||||||
|
prompt="",
|
||||||
|
length_penalty=0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
function for open-ended VQA
|
||||||
|
'''
|
||||||
|
images = samples["image"].cuda()
|
||||||
|
texts = samples["instruction_input"]
|
||||||
|
|
||||||
|
output_text = self.generate(
|
||||||
|
images=images,
|
||||||
|
texts=texts,
|
||||||
|
num_beams=num_beams,
|
||||||
|
max_new_tokens=max_len,
|
||||||
|
min_length=min_len,
|
||||||
|
length_penalty=length_penalty
|
||||||
|
)
|
||||||
|
|
||||||
|
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
||||||
|
output_text = self._lemmatize(output_text)
|
||||||
|
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
def predict_class(
|
||||||
|
self,
|
||||||
|
samples,
|
||||||
|
num_beams=5,
|
||||||
|
inference_method="generate",
|
||||||
|
max_len=10,
|
||||||
|
min_len=1,
|
||||||
|
num_ans_candidates=5,
|
||||||
|
answer_list=None,
|
||||||
|
prompt="",
|
||||||
|
length_penalty=0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
function for multi-choice VQA
|
||||||
|
'''
|
||||||
|
|
||||||
|
image = samples["image"].cuda()
|
||||||
|
instruction = samples['instruction_input']
|
||||||
|
answers = samples["choices"]
|
||||||
|
num_cand = samples["num_choices"]
|
||||||
|
|
||||||
|
ranks = self.multi_select(image, instruction, answers, num_cand)
|
||||||
|
|
||||||
|
pred_ans = []
|
||||||
|
for i, rank in enumerate(ranks):
|
||||||
|
pred = answers[rank[0]][i]
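# rank[0] is the index of the lowest-loss (i.e. most likely) candidate for sample i.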
|
||||||
|
pred_ans.append(pred)
|
||||||
|
return pred_ans
|
||||||
|
|
||||||
|
def embed_tokens(self, token_ids):
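# LoRA/PEFT-wrapped models nest the embedding layer one level deeper
# (base_model.model.model); otherwise fall back to the plain LLaMA layout.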
|
||||||
|
try:
|
||||||
|
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
|
||||||
|
except AttributeError:
|
||||||
|
embeds = self.llama_model.model.embed_tokens(token_ids)
|
||||||
|
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, cfg):
|
||||||
|
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||||
|
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
|
||||||
|
img_size = cfg.get("image_size")
|
||||||
|
num_query_token = cfg.get("num_query_token")
|
||||||
|
llama_model = cfg.get("llama_model")
|
||||||
|
|
||||||
|
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||||
|
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||||
|
vit_precision = cfg.get("vit_precision", "fp16")
|
||||||
|
freeze_vit = cfg.get("freeze_vit", True)
|
||||||
|
freeze_qformer = cfg.get("freeze_qformer", True)
|
||||||
|
low_resource = cfg.get("low_resource", False)
|
||||||
|
|
||||||
|
prompt_path = cfg.get("prompt_path", "")
|
||||||
|
prompt_template = cfg.get("prompt_template", "")
|
||||||
|
max_txt_len = cfg.get("max_txt_len", 300)
|
||||||
|
end_sym = cfg.get("end_sym", '\n')
|
||||||
|
|
||||||
|
lora_r = cfg.get("lora_r",64)
|
||||||
|
lora_alpha = cfg.get("lora_alpha",16)
|
||||||
|
chat_template = cfg.get("chat_template",False)
|
||||||
|
system_prompt = cfg.get("system_prompt", False)
|
||||||
|
token_pooling = cfg.get("token_pooling",True)
|
||||||
|
|
||||||
|
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
|
||||||
|
max_context_len = cfg.get("max_context_len", 3800)
|
||||||
|
remove_template = cfg.get("remove_template", False)
|
||||||
|
|
||||||
|
|
||||||
|
model = cls(
|
||||||
|
vit_model=vit_model,
|
||||||
|
img_size=img_size,
|
||||||
|
drop_path_rate=drop_path_rate,
|
||||||
|
use_grad_checkpoint=use_grad_checkpoint,
|
||||||
|
vit_precision=vit_precision,
|
||||||
|
freeze_vit=freeze_vit,
|
||||||
|
llama_model=llama_model,
|
||||||
|
prompt_path=prompt_path,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
max_txt_len=max_txt_len,
|
||||||
|
low_resource=low_resource,
|
||||||
|
end_sym=end_sym,
|
||||||
|
lora_r = lora_r,
|
||||||
|
lora_alpha = lora_alpha,
|
||||||
|
chat_template = chat_template,
|
||||||
|
system_prompt = system_prompt,
|
||||||
|
token_pooling = token_pooling,
|
||||||
|
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
|
||||||
|
max_context_len=max_context_len,
|
||||||
|
remove_template = remove_template
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||||
|
if ckpt_path:
|
||||||
|
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def assign_imgs(batched_instruct_list, batched_img_embeds):
|
||||||
|
'''This function is used when the data is interleaved:
the interleaved data is split into segments, and this function assigns the
corresponding image embeddings to each segment.'''
|
||||||
|
if len(batched_img_embeds.shape) == 3:
|
||||||
|
batched_img_embeds = batched_img_embeds[:, None]
|
||||||
|
|
||||||
|
batched_assigned = []
|
||||||
|
|
||||||
|
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
|
||||||
|
img_idx = 0
|
||||||
|
assigned_img = []
|
||||||
|
n_assigned = []
|
||||||
|
for instruct in instruct_list:
|
||||||
|
n_img = instruct.count('<ImageHere>')
|
||||||
|
if n_img > 0: # this instruction includes images.
|
||||||
|
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
|
||||||
|
img_idx += n_img
|
||||||
|
n_assigned.append(n_img)
|
||||||
|
else: # this instruction doesn't include images
|
||||||
|
assigned_img.append(None)
|
||||||
|
n_assigned.append(None)
|
||||||
|
batched_assigned.append(assigned_img)
|
||||||
|
|
||||||
|
return batched_assigned
|
709
models/backbones/mini_gpt4v.py
Executable file
@@ -0,0 +1,709 @@
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast as autocast
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from minigpt4.common.registry import registry
|
||||||
|
from minigpt4.models.blip2 import Blip2Base, disabled_train
|
||||||
|
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
|
||||||
|
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig
|
||||||
|
|
||||||
|
from peft import (
|
||||||
|
LoraConfig,
|
||||||
|
get_peft_model,
|
||||||
|
prepare_model_for_kbit_training
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from minigpt4.models import policies
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register_model("mini_gpt4v")
|
||||||
|
class MiniGPT4v(Blip2Base):
|
||||||
|
"""
|
||||||
|
BLIP2 GPT-LLAMA model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||||
|
"pretrain_vicuna": "configs/models/minigpt4.yaml",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vit_model="eva_clip_g",
|
||||||
|
img_size=224,
|
||||||
|
drop_path_rate=0,
|
||||||
|
use_grad_checkpoint=False,
|
||||||
|
vit_precision="fp16",
|
||||||
|
freeze_vit=True,
|
||||||
|
llama_model="",
|
||||||
|
prompt_path="",
|
||||||
|
prompt_template="",
|
||||||
|
max_txt_len=32,
|
||||||
|
low_resource=False, # use 8 bit and put vit in cpu
|
||||||
|
end_sym='\n',
|
||||||
|
lora_r = 8,
|
||||||
|
lora_target_modules = ["q_proj","v_proj"],
|
||||||
|
lora_alpha=16,
|
||||||
|
# lora_r = 16,
|
||||||
|
# lora_target_modules = ["q_proj","v_proj","v_proj"],
|
||||||
|
lora_dropout= 0.05,
|
||||||
|
ckpt_path = "",
|
||||||
|
system_prompt= False,
|
||||||
|
chat_template=False,
|
||||||
|
token_pooling=True,
|
||||||
|
use_grad_checkpoint_llm=False,
|
||||||
|
max_context_len=3800,
|
||||||
|
remove_template = False,
|
||||||
|
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.tokenizer = self.init_tokenizer()
|
||||||
|
self.low_resource = low_resource
|
||||||
|
self.token_pooling = token_pooling
|
||||||
|
self.remove_template = remove_template
|
||||||
|
|
||||||
|
print("token pooling", self.token_pooling)
|
||||||
|
|
||||||
|
|
||||||
|
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
|
||||||
|
self.max_context_len = max_context_len
|
||||||
|
self.chat_template = chat_template
|
||||||
|
|
||||||
|
# print('Loading VIT')
|
||||||
|
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||||
|
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
print("vit precision", vit_precision)
|
||||||
|
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||||
|
vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||||
|
)
|
||||||
|
for name, param in self.visual_encoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.visual_encoder = self.visual_encoder.eval()
|
||||||
|
self.visual_encoder.train = disabled_train
|
||||||
|
for name, param in self.ln_vision.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.ln_vision = self.ln_vision.eval()
|
||||||
|
self.ln_vision.train = disabled_train
|
||||||
|
logging.info("freeze vision encoder")
|
||||||
|
print("freeze the vision encoder")
|
||||||
|
|
||||||
|
|
||||||
|
print('Loading VIT Done')
|
||||||
|
|
||||||
|
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
|
||||||
|
# assert False
|
||||||
|
|
||||||
|
print('Loading LLAMA')
|
||||||
|
|
||||||
|
|
||||||
|
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||||
|
|
||||||
|
if 'CodeLlama' in llama_model:
|
||||||
|
self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
|
||||||
|
self.llama_tokenizer.pad_token = "$$"
|
||||||
|
else:
|
||||||
|
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
|
||||||
|
self.llama_tokenizer.pad_token = "$$"
|
||||||
|
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.llama_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
llama_model,
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
device_map={"": 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.llama_model.gradient_checkpointing_enable()
|
||||||
|
self.llama_model = prepare_model_for_kbit_training(self.llama_model)
|
||||||
|
|
||||||
|
# self.llama_model.print_trainable_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
print('Loading LLAMA Done')
|
||||||
|
|
||||||
|
self.merge_n = 3
|
||||||
|
|
||||||
|
self.llama_proj = nn.Linear(
|
||||||
|
1408 * self.merge_n**2, self.llama_model.config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.max_txt_len = max_txt_len
|
||||||
|
self.end_sym = end_sym
|
||||||
|
|
||||||
|
if prompt_path:
|
||||||
|
with open(prompt_path, 'r') as f:
|
||||||
|
raw_prompts = f.read().splitlines()
|
||||||
|
filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
|
||||||
|
print('Load {} training prompts'.format(len(self.prompt_list)))
|
||||||
|
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
|
||||||
|
else:
|
||||||
|
self.prompt_list = []
|
||||||
|
|
||||||
|
def encode_img(self, image):
|
||||||
|
device = image.device
|
||||||
|
if len(image.shape) > 4:
|
||||||
|
image = image.reshape(-1, *image.shape[-3:])
|
||||||
|
|
||||||
|
bs, ch, w, h = image.shape
|
||||||
|
assert w % 224 == 0
|
||||||
|
bw = w // 224
|
||||||
|
assert h % 224 == 0
|
||||||
|
bh = h // 224
|
||||||
|
image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224
|
||||||
|
image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)
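# The high-resolution input is tiled into 224x224 crops so the frozen ViT
# (trained at 224x224 resolution) can encode each crop independently.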
|
||||||
|
|
||||||
|
with self.maybe_autocast():
|
||||||
|
image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
|
||||||
|
|
||||||
|
image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
|
||||||
|
image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs
|
||||||
|
image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])
|
||||||
|
|
||||||
|
bs, pn, hs = image_embeds.shape
|
||||||
|
|
||||||
|
image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2))
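# Merge merge_n**2 (here 9) neighbouring visual tokens by concatenating them
# along the channel dimension, shrinking the sequence length fed to the LLM.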
|
||||||
|
|
||||||
|
inputs_llama = self.llama_proj(image_embeds)
|
||||||
|
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||||
|
return inputs_llama, atts_llama
|
||||||
|
|
||||||
|
def get_context_emb(self, prompt, img_list):
|
||||||
|
img_device = img_list[0].device
|
||||||
|
prompt_segs = prompt.split('<ImageHere>')
|
||||||
|
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||||
|
seg_tokens = [
|
||||||
|
self.llama_tokenizer(
|
||||||
|
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
|
||||||
|
for i, seg in enumerate(prompt_segs)
|
||||||
|
]
|
||||||
|
|
||||||
|
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||||
|
|
||||||
|
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||||
|
|
||||||
|
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||||
|
return mixed_embs
|
||||||
|
|
||||||
|
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
|
||||||
|
if prompts is None or len(prompts) == 0:
|
||||||
|
# no prompts provided; just return the original image embeddings
|
||||||
|
return img_embeds, atts_img
|
||||||
|
elif img_embeds is None:
|
||||||
|
# a prompt is provided but there is no image embedding; return the right-padded prompt embeddings
|
||||||
|
self.llama_tokenizer.padding_side = "right"
|
||||||
|
prompt_tokens = self.llama_tokenizer(
|
||||||
|
prompts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
add_special_tokens=False
|
||||||
|
).to(self.device)
|
||||||
|
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
|
||||||
|
atts_prompt = prompt_tokens.attention_mask
|
||||||
|
return prompt_embeds, atts_prompt
|
||||||
|
|
||||||
|
else:
|
||||||
|
# return the multi-modal embedding in right padding
|
||||||
|
emb_lists = []
|
||||||
|
|
||||||
|
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
|
||||||
|
pn = each_img_embed.shape[-2]
|
||||||
|
if lengths is not None:
|
||||||
|
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
|
||||||
|
each_img_embed = each_img_embed[:lengths[idx] * pn]
|
||||||
|
|
||||||
|
p_segs = each_prompt.split('<ImageHere>')
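# Each '<ImageHere>' placeholder is replaced by one block of pn image tokens,
# interleaving text and visual embeddings in the order they appear in the prompt.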
|
||||||
|
interleave_emb = []
|
||||||
|
for idx, seg in enumerate(p_segs[:-1]):
|
||||||
|
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||||
|
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||||
|
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
|
||||||
|
|
||||||
|
wrapped_emb = torch.cat(interleave_emb, dim=1)
|
||||||
|
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||||
|
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||||
|
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
|
||||||
|
emb_lists.append(wrapped_emb)
|
||||||
|
|
||||||
|
emb_lens = [emb.shape[1] for emb in emb_lists]
|
||||||
|
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
|
||||||
|
|
||||||
|
max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
|
||||||
|
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
|
||||||
|
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
|
||||||
|
|
||||||
|
for i, emb in enumerate(emb_lists):
|
||||||
|
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
|
||||||
|
wrapped_embs[i, :length] = emb[:, :length]
|
||||||
|
wrapped_atts[i, :length] = 1
|
||||||
|
|
||||||
|
return wrapped_embs, wrapped_atts
|
||||||
|
|
||||||
|
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
|
||||||
|
"""
|
||||||
|
Concatenate the batched input embedding and batched output embedding together.
|
||||||
|
Both the input and the output embedding should be right padded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_lens = []
|
||||||
|
cat_embs = []
|
||||||
|
cat_atts = []
|
||||||
|
|
||||||
|
for i in range(input_embs.size(0)):
|
||||||
|
input_len = input_atts[i].sum()
|
||||||
|
input_lens.append(input_len)
|
||||||
|
|
||||||
|
cat_embs.append(
|
||||||
|
torch.cat([
|
||||||
|
input_embs[i][:input_len],
|
||||||
|
output_embs[i],
|
||||||
|
input_embs[i][input_len:]
|
||||||
|
])
|
||||||
|
)
|
||||||
|
cat_atts.append(
|
||||||
|
torch.cat([
|
||||||
|
input_atts[i][:input_len],
|
||||||
|
output_atts[i],
|
||||||
|
input_atts[i][input_len:]
|
||||||
|
])
|
||||||
|
)
|
||||||
|
# print('===================================')
|
||||||
|
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
|
||||||
|
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
|
||||||
|
# print('check out emb: ', output_embs[i][:2])
|
||||||
|
# print('check out pad emb: ', output_embs[i][-2:])
|
||||||
|
# print('+++++++++++++++++++++++++++++++++++')
|
||||||
|
#
|
||||||
|
# print('check attn before: ', input_atts[i][:this_input_ones])
|
||||||
|
# print('check attn after: ', input_atts[i][this_input_ones:])
|
||||||
|
# print('check attn gt before: ', output_atts[i][:3])
|
||||||
|
# print('check attn gt after: ', output_atts[i][-3:])
|
||||||
|
|
||||||
|
cat_embs = torch.stack(cat_embs)
|
||||||
|
cat_atts = torch.stack(cat_atts)
|
||||||
|
return cat_embs, cat_atts, input_lens
|
||||||
|
|
||||||
|
def get_conv_emb(self, conv_q, conv_a, conv_img):
|
||||||
|
"""concatenate conversation and make sure the model is only trained to regress the answer"""
|
||||||
|
|
||||||
|
regress_embs_list = []
|
||||||
|
targets_list = []
|
||||||
|
|
||||||
|
batch_size = len(conv_q)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
|
||||||
|
assigned_imgs = conv_img[batch_idx]
|
||||||
|
questions = [self.prompt_wrap(
|
||||||
|
img_embeds=img,
|
||||||
|
atts_img=None,
|
||||||
|
prompts=[q],
|
||||||
|
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
|
||||||
|
q_embs = [emb for emb, _ in questions]
|
||||||
|
|
||||||
|
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
|
||||||
|
cur_emb = []
|
||||||
|
cur_target = []
|
||||||
|
for i in range(len(questions)):
|
||||||
|
cur_emb.append(q_embs[i])
|
||||||
|
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
|
||||||
|
|
||||||
|
cur_emb.append(self.embed_tokens(answers[i].input_ids))
|
||||||
|
cur_target.append(answers[i].input_ids)
|
||||||
|
|
||||||
|
cur_emb = torch.cat(cur_emb, dim=1)
|
||||||
|
cur_target = torch.cat(cur_target, dim=1)
|
||||||
|
|
||||||
|
regress_embs_list.append(cur_emb)
|
||||||
|
targets_list.append(cur_target)
|
||||||
|
|
||||||
|
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
|
||||||
|
|
||||||
|
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
|
||||||
|
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
|
||||||
|
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
cur_len = regress_embs_list[batch_idx].shape[1]
|
||||||
|
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
|
||||||
|
regress_attn[batch_idx, :cur_len] = 1
|
||||||
|
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
|
||||||
|
|
||||||
|
return regress_embeds, regress_attn, targets
|
||||||
|
|
||||||
|
def preparing_embedding(self, samples):
|
||||||
|
def remove_special_tokens(data):
|
||||||
|
|
||||||
|
# if "instruction_input" in data:
|
||||||
|
data = [instruct.replace(" [caption]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [vqa]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [grounding]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [identify]","") for instruct in data]
|
||||||
|
data = [instruct.replace(" [refer]","") for instruct in data]
|
||||||
|
return data
|
||||||
|
|
||||||
|
### prepare input tokens
|
||||||
|
if 'image' in samples:
|
||||||
|
img_embeds, img_atts = self.encode_img(samples["image"])
|
||||||
|
else:
|
||||||
|
img_embeds = img_atts = None
|
||||||
|
|
||||||
|
if 'conv_q' in samples:
|
||||||
|
# handling conversation datasets
|
||||||
|
conv_q, conv_a = samples['conv_q'], samples['conv_a']
|
||||||
|
|
||||||
|
connect_sym = samples['connect_sym'][0]
|
||||||
|
conv_q = [q.split(connect_sym)for q in conv_q]
|
||||||
|
conv_a = [a.split(connect_sym) for a in conv_a]
|
||||||
|
conv_img = assign_imgs(conv_q, img_embeds)
|
||||||
|
|
||||||
|
if self.chat_template:
|
||||||
|
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
|
||||||
|
|
||||||
|
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
|
||||||
|
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
instruction = samples["instruction_input"] if "instruction_input" in samples else None
|
||||||
|
|
||||||
|
# print("instruction before", instruction)
|
||||||
|
if self.remove_template:
|
||||||
|
instruction = remove_special_tokens(instruction)
|
||||||
|
# print("instruction after", instruction)
|
||||||
|
|
||||||
|
if self.chat_template:
|
||||||
|
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
|
||||||
|
|
||||||
|
if 'length' in samples:
|
||||||
|
# the input is a sequence of images (e.g. video frames)
|
||||||
|
bsz, pn, hs = img_embeds.shape
|
||||||
|
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
|
||||||
|
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
|
||||||
|
else:
|
||||||
|
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
|
||||||
|
|
||||||
|
### prepare target tokens
|
||||||
|
self.llama_tokenizer.padding_side = "right"
|
||||||
|
text = [t + self.end_sym for t in samples["answer"]]
|
||||||
|
|
||||||
|
regress_tokens = self.llama_tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_txt_len,
|
||||||
|
add_special_tokens=False
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
regress_token_ids = regress_tokens.input_ids
|
||||||
|
regress_atts = regress_tokens.attention_mask
|
||||||
|
part_targets = regress_token_ids.masked_fill(
|
||||||
|
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
|
||||||
|
)
|
||||||
|
|
||||||
|
regress_embeds = self.embed_tokens(regress_token_ids)
|
||||||
|
|
||||||
|
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
|
||||||
|
|
||||||
|
def forward(self, samples, reduction="mean"):
|
||||||
|
# prepare the embedding to condition and the embedding to regress
|
||||||
|
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
|
||||||
|
self.preparing_embedding(samples)
|
||||||
|
|
||||||
|
# concat the embedding to condition and the embedding to regress
|
||||||
|
inputs_embeds, attention_mask, input_lens = \
|
||||||
|
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
|
||||||
|
|
||||||
|
# get bos token embedding
|
||||||
|
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
|
||||||
|
bos_embeds = self.embed_tokens(bos)
|
||||||
|
bos_atts = attention_mask[:, :1]
|
||||||
|
|
||||||
|
# add bos token at the beginning
|
||||||
|
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
|
||||||
|
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
|
||||||
|
|
||||||
|
# ensemble the final targets
|
||||||
|
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
|
||||||
|
dtype=torch.long).to(self.device).fill_(-100)
|
||||||
|
for i, target in enumerate(part_targets):
|
||||||
|
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
|
||||||
|
|
||||||
|
with self.maybe_autocast():
|
||||||
|
outputs = self.llama_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
return_dict=True,
|
||||||
|
labels=targets,
|
||||||
|
reduction=reduction
|
||||||
|
)
|
||||||
|
loss = outputs.loss
|
||||||
|
|
||||||
|
return {"loss": loss}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
texts,
|
||||||
|
use_nucleus_sampling=False,
|
||||||
|
num_beams=1,
|
||||||
|
max_new_tokens=20,
|
||||||
|
min_length=1,
|
||||||
|
top_p=0.9,
|
||||||
|
repetition_penalty=1,
|
||||||
|
length_penalty=1,
|
||||||
|
temperature=1,
|
||||||
|
do_sample=False,
|
||||||
|
stop_words_ids=[2],
|
||||||
|
lengths=None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Generation helper used at inference / test time.
|
||||||
|
'''
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||||
|
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||||
|
|
||||||
|
img_embeds, atts_img = self.encode_img(images.to(self.device))
|
||||||
|
if lengths is not None:
|
||||||
|
image_lists = []
|
||||||
|
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
|
||||||
|
for idx, img_embed in enumerate(img_embeds):
|
||||||
|
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
|
||||||
|
else:
|
||||||
|
image_lists = [[image_emb[None]] for image_emb in img_embeds]
|
||||||
|
assert len(texts) == len(image_lists)
|
||||||
|
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
|
||||||
|
|
||||||
|
batch_size = len(batch_embs)
|
||||||
|
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||||
|
emb_dim = batch_embs[0].shape[2]
|
||||||
|
dtype = batch_embs[0].dtype
|
||||||
|
device = batch_embs[0].device
|
||||||
|
|
||||||
|
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||||
|
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||||
|
for i, emb in enumerate(batch_embs):
|
||||||
|
emb_len = emb.shape[1]
|
||||||
|
embs[i, -emb_len:] = emb[0]
|
||||||
|
attn_mask[i, -emb_len:] = 1
|
||||||
|
|
||||||
|
with self.maybe_autocast():
|
||||||
|
outputs = self.llama_model.generate(
|
||||||
|
inputs_embeds=embs,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
num_beams=num_beams,
|
||||||
|
do_sample=do_sample,
|
||||||
|
# stopping_criteria=stopping_criteria,
|
||||||
|
)
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
for output_token in outputs:
|
||||||
|
if output_token[0] == 0:
|
||||||
|
output_token = output_token[1:]
|
||||||
|
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||||
|
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||||
|
output_texts = output_texts.replace("<s>", "")
|
||||||
|
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||||
|
answers.append(output_texts)
|
||||||
|
|
||||||
|
return answers
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def multi_select(self, images, texts, answers, num_cand=None):
|
||||||
|
all_losses = []
|
||||||
|
for answer in answers:
|
||||||
|
choice_samples = {
|
||||||
|
'image': images,
|
||||||
|
'instruction_input': texts,
|
||||||
|
'answer': answer
|
||||||
|
}
|
||||||
|
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
|
||||||
|
all_losses.append(loss)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
all_losses = torch.cat(all_losses, dim=-1)
|
||||||
|
if num_cand is not None:
|
||||||
|
for i in range(all_losses.shape[0]):
|
||||||
|
all_losses[i, num_cand[i]:] = 9999
|
||||||
|
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
||||||
|
return output_class_ranks.tolist()
|
||||||
|
|
||||||
|
def predict_answers(
|
||||||
|
self,
|
||||||
|
samples,
|
||||||
|
num_beams=5,
|
||||||
|
inference_method="generate",
|
||||||
|
max_len=10,
|
||||||
|
min_len=1,
|
||||||
|
num_ans_candidates=128,
|
||||||
|
answer_list=None,
|
||||||
|
prompt="",
|
||||||
|
length_penalty=0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
function for open-ended VQA
|
||||||
|
'''
|
||||||
|
images = samples["image"].cuda()
|
||||||
|
texts = samples["instruction_input"]
|
||||||
|
|
||||||
|
output_text = self.generate(
|
||||||
|
images=images,
|
||||||
|
texts=texts,
|
||||||
|
num_beams=num_beams,
|
||||||
|
max_new_tokens=max_len,
|
||||||
|
min_length=min_len,
|
||||||
|
length_penalty=length_penalty
|
||||||
|
)
|
||||||
|
|
||||||
|
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
||||||
|
output_text = self._lemmatize(output_text)
|
||||||
|
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
def predict_class(
|
||||||
|
self,
|
||||||
|
samples,
|
||||||
|
num_beams=5,
|
||||||
|
inference_method="generate",
|
||||||
|
max_len=10,
|
||||||
|
min_len=1,
|
||||||
|
num_ans_candidates=5,
|
||||||
|
answer_list=None,
|
||||||
|
prompt="",
|
||||||
|
length_penalty=0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
function for multi-choice VQA
|
||||||
|
'''
|
||||||
|
|
||||||
|
image = samples["image"].cuda()
|
||||||
|
instruction = samples['instruction_input']
|
||||||
|
answers = samples["choices"]
|
||||||
|
num_cand = samples["num_choices"]
|
||||||
|
|
||||||
|
ranks = self.multi_select(image, instruction, answers, num_cand)
|
||||||
|
|
||||||
|
pred_ans = []
|
||||||
|
for i, rank in enumerate(ranks):
|
||||||
|
pred = answers[rank[0]][i]
|
||||||
|
pred_ans.append(pred)
|
||||||
|
return pred_ans
|
||||||
|
|
||||||
|
def embed_tokens(self, token_ids):
|
||||||
|
try:
|
||||||
|
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
|
||||||
|
except AttributeError:
|
||||||
|
embeds = self.llama_model.model.embed_tokens(token_ids)
|
||||||
|
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, cfg):
|
||||||
|
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||||
|
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
|
||||||
|
img_size = cfg.get("image_size")
|
||||||
|
num_query_token = cfg.get("num_query_token")
|
||||||
|
llama_model = cfg.get("llama_model")
|
||||||
|
|
||||||
|
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||||
|
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||||
|
vit_precision = cfg.get("vit_precision", "fp16")
|
||||||
|
freeze_vit = cfg.get("freeze_vit", True)
|
||||||
|
freeze_qformer = cfg.get("freeze_qformer", True)
|
||||||
|
low_resource = cfg.get("low_resource", False)
|
||||||
|
|
||||||
|
prompt_path = cfg.get("prompt_path", "")
|
||||||
|
prompt_template = cfg.get("prompt_template", "")
|
||||||
|
max_txt_len = cfg.get("max_txt_len", 300)
|
||||||
|
end_sym = cfg.get("end_sym", '\n')
|
||||||
|
|
||||||
|
lora_r = cfg.get("lora_r",64)
|
||||||
|
lora_alpha = cfg.get("lora_alpha",16)
|
||||||
|
chat_template = cfg.get("chat_template",False)
|
||||||
|
system_prompt = cfg.get("system_prompt", False)
|
||||||
|
token_pooling = cfg.get("token_pooling",True)
|
||||||
|
|
||||||
|
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
|
||||||
|
max_context_len = cfg.get("max_context_len", 3800)
|
||||||
|
remove_template = cfg.get("remove_template", False)
|
||||||
|
|
||||||
|
|
||||||
|
model = cls(
|
||||||
|
vit_model=vit_model,
|
||||||
|
img_size=img_size,
|
||||||
|
drop_path_rate=drop_path_rate,
|
||||||
|
use_grad_checkpoint=use_grad_checkpoint,
|
||||||
|
vit_precision=vit_precision,
|
||||||
|
freeze_vit=freeze_vit,
|
||||||
|
llama_model=llama_model,
|
||||||
|
prompt_path=prompt_path,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
max_txt_len=max_txt_len,
|
||||||
|
low_resource=low_resource,
|
||||||
|
end_sym=end_sym,
|
||||||
|
lora_r = lora_r,
|
||||||
|
lora_alpha = lora_alpha,
|
||||||
|
chat_template = chat_template,
|
||||||
|
system_prompt = system_prompt,
|
||||||
|
token_pooling = token_pooling,
|
||||||
|
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
|
||||||
|
max_context_len=max_context_len,
|
||||||
|
remove_template = remove_template
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||||
|
if ckpt_path:
|
||||||
|
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def assign_imgs(batched_instruct_list, batched_img_embeds):
|
||||||
|
'''This function is used when the data is interleaved:
the interleaved data is split into segments, and this function assigns the
corresponding image embeddings to each segment.'''
|
||||||
|
if len(batched_img_embeds.shape) == 3:
|
||||||
|
batched_img_embeds = batched_img_embeds[:, None]
|
||||||
|
|
||||||
|
batched_assigned = []
|
||||||
|
|
||||||
|
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
|
||||||
|
img_idx = 0
|
||||||
|
assigned_img = []
|
||||||
|
n_assigned = []
|
||||||
|
for instruct in instruct_list:
|
||||||
|
n_img = instruct.count('<ImageHere>')
|
||||||
|
if n_img > 0: # this instruction includes images.
|
||||||
|
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
|
||||||
|
img_idx += n_img
|
||||||
|
n_assigned.append(n_img)
|
||||||
|
else: # this instruction doesn't include images
|
||||||
|
assigned_img.append(None)
|
||||||
|
n_assigned.append(None)
|
||||||
|
batched_assigned.append(assigned_img)
|
||||||
|
|
||||||
|
return batched_assigned
|
25
models/backbones/mistral.py
Normal file
@@ -0,0 +1,25 @@
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
device = "cuda" # the device to load the model onto
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "What is your favourite condiment?"},
|
||||||
|
{"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
|
||||||
|
{"role": "user", "content": "Do you have mayonnaise recipes?"}
|
||||||
|
]
|
||||||
|
p="Well, I'm quite partial to a good squeeze of fresh lemon juice."
|
||||||
|
encoded_input = tokenizer(p, return_tensors='pt')
|
||||||
|
embeds = model.model.embed_tokens(encoded_input.input_ids)
|
||||||
|
print(embeds.shape)
|
||||||
|
|
||||||
|
|
||||||
|
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
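# apply_chat_template wraps the messages in Mistral's [INST] ... [/INST] chat
# format and returns the corresponding token ids.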
|
||||||
|
model_inputs = encodeds.to(device)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
|
||||||
|
decoded = tokenizer.batch_decode(generated_ids)
|
||||||
|
print(decoded[0])
|
112
models/backbones/modeling_llama_v2.py
Normal file
@@ -0,0 +1,112 @@
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
reduction: Optional[str] = "mean",
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||||
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
logits = torch.cat(logits, dim=-1)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss(reduction=reduction)
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
if reduction == "none":
|
||||||
|
loss = loss.view(logits.size(0), -1).mean(1)
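# With reduction="none", return one mean loss per sample; callers such as
# multi_select use these per-sample losses to rank answer candidates.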
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
112
models/backbones/modeling_llama_v3.py
Normal file
@@ -0,0 +1,112 @@
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
reduction: Optional[str] = "mean",
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||||
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
logits = torch.cat(logits, dim=-1)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss(reduction=reduction)
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
if reduction == "none":
|
||||||
|
loss = loss.view(logits.size(0), -1).mean(1)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
1388
models/backbones/modeling_mistral.py
Normal file
1388
models/backbones/modeling_mistral.py
Normal file
File diff suppressed because it is too large
Load diff
287
models/backbones/moes.py
Normal file
287
models/backbones/moes.py
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
drop=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=None,
|
||||||
|
attn_drop=0.0,
|
||||||
|
proj_drop=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = (
|
||||||
|
self.qkv(x)
|
||||||
|
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||||
|
.permute(2, 0, 3, 1, 4)
|
||||||
|
)
|
||||||
|
q, k, v = (
|
||||||
|
qkv[0],
|
||||||
|
qkv[1],
|
||||||
|
qkv[2],
|
||||||
|
) # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
if mask is not None:
|
||||||
|
# if mask.dim() != x.dim():
|
||||||
|
# expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
|
||||||
|
# else:
|
||||||
|
# expanded_mask = mask
|
||||||
|
mask = mask.bool()
|
||||||
|
attn = attn.masked_fill(~mask, float("-inf"))
|
||||||
|
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x, attn
|
||||||
|
|
||||||
|
|
||||||
|
class MoELayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
expert_type,
|
||||||
|
use_sep_spatial_temp_experts=True,
|
||||||
|
has_hist=False,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.0,
|
||||||
|
attn_drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=LlamaRMSNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.has_hist = has_hist
|
||||||
|
self.use_sep_spatial_temp_experts = use_sep_spatial_temp_experts
|
||||||
|
self.norm_att = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
proj_drop=drop,
|
||||||
|
)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
|
||||||
|
if expert_type == 'modalities':
|
||||||
|
# EXPERT CONSTRUCTION
|
||||||
|
if use_sep_spatial_temp_experts:
|
||||||
|
# Spatial expert
|
||||||
|
self.norm_spatial = norm_layer(dim)
|
||||||
|
self.mlp_spatial = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Temporal expert
|
||||||
|
self.norm_temp = norm_layer(dim)
|
||||||
|
self.mlp_temp = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vis expert
|
||||||
|
self.norm_vis = norm_layer(dim)
|
||||||
|
self.mlp_vis = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# caption expert
|
||||||
|
self.norm_cap = norm_layer(dim)
|
||||||
|
self.mlp_cap = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# history expert
|
||||||
|
if has_hist:
|
||||||
|
self.norm_hist = norm_layer(dim)
|
||||||
|
self.mlp_hist = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif expert_type == 'fusion':
|
||||||
|
# Fusion expert
|
||||||
|
self.norm_fusion = norm_layer(dim)
|
||||||
|
self.mlp_fusion = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len=None, is_vid=False, mask=None, only_text=False, expert_permutation=None):
|
||||||
|
|
||||||
|
if self.has_hist:
|
||||||
|
assert hist_feat_len is not None
|
||||||
|
|
||||||
|
x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
|
||||||
|
x = x + self.drop_path(x_shortcut)
|
||||||
|
len_init = x.size(1)
|
||||||
|
# bs, h_dim = x.size(0), x.size(-1)
|
||||||
|
# device = x.device
|
||||||
|
# if only_text:
|
||||||
|
# # end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
|
||||||
|
# # x = x[:, special_toks_indices['<caption>']: end_idx_caption, :]
|
||||||
|
# x = x + self.drop_path(self.mlp_cap(self.norm_cap(x)))
|
||||||
|
|
||||||
|
if expert_flag == 'modalities':
|
||||||
|
if self.use_sep_spatial_temp_experts:
|
||||||
|
x_spatial = x[:, :vis_feat_len]
|
||||||
|
if expert_permutation is not None:
|
||||||
|
if expert_permutation['spatial'] == 'temporal':
|
||||||
|
x_spatial = x_spatial + self.drop_path(self.mlp_temp(self.norm_temp(x_spatial)))
|
||||||
|
elif expert_permutation['spatial'] == 'caption':
|
||||||
|
x_spatial = x_spatial + self.drop_path(self.mlp_cap(self.norm_cap(x_spatial)))
|
||||||
|
elif expert_permutation['spatial'] == 'history':
|
||||||
|
x_spatial = x_spatial + self.drop_path(self.mlp_hist(self.norm_hist(x_spatial)))
|
||||||
|
elif expert_permutation['spatial'] == 'spatial':
|
||||||
|
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
|
||||||
|
x_vis = x_spatial
|
||||||
|
|
||||||
|
else:
|
||||||
|
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
|
||||||
|
x_vis = x_spatial
|
||||||
|
|
||||||
|
if is_vid:
|
||||||
|
x_temporal = x[:, vis_feat_len:2*vis_feat_len]
|
||||||
|
if expert_permutation is not None:
|
||||||
|
if expert_permutation['temporal'] == 'spatial':
|
||||||
|
x_temporal = x_temporal + self.drop_path(self.mlp_spatial(self.norm_spatial(x_temporal)))
|
||||||
|
elif expert_permutation['temporal'] == 'caption':
|
||||||
|
x_temporal = x_temporal + self.drop_path(self.mlp_cap(self.norm_cap(x_temporal)))
|
||||||
|
elif expert_permutation['temporal'] == 'history':
|
||||||
|
x_temporal = x_temporal + self.drop_path(self.mlp_hist(self.norm_hist(x_temporal)))
|
||||||
|
elif expert_permutation['temporal'] == 'temporal':
|
||||||
|
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
|
||||||
|
else:
|
||||||
|
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
|
||||||
|
x_vis = torch.concat([x_spatial, x_temporal], dim=1)
|
||||||
|
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
|
||||||
|
else:
|
||||||
|
x_vis = x[:, :vis_feat_len]
|
||||||
|
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
|
||||||
|
|
||||||
|
if self.has_hist:
|
||||||
|
x_caption = x[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
|
||||||
|
if expert_permutation is not None:
|
||||||
|
if expert_permutation['caption'] == 'spatial':
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_spatial(self.norm_spatial(x_caption)))
|
||||||
|
elif expert_permutation['caption'] == 'temporal':
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_temp(self.norm_temp(x_caption)))
|
||||||
|
elif expert_permutation['caption'] == 'history':
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_hist(self.norm_hist(x_caption)))
|
||||||
|
elif expert_permutation['caption'] == 'caption':
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||||
|
else:
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||||
|
|
||||||
|
|
||||||
|
x_history = x[:, -hist_feat_len:]
|
||||||
|
if expert_permutation is not None:
|
||||||
|
if expert_permutation['history'] == 'spatial':
|
||||||
|
x_history = x_history + self.drop_path(self.mlp_spatial(self.norm_spatial(x_history)))
|
||||||
|
elif expert_permutation['history'] == 'temporal':
|
||||||
|
x_history = x_history + self.drop_path(self.mlp_temp(self.norm_temp(x_history)))
|
||||||
|
elif expert_permutation['history'] == 'caption':
|
||||||
|
x_history = x_history + self.drop_path(self.mlp_cap(self.norm_cap(x_history)))
|
||||||
|
elif expert_permutation['history'] == 'history':
|
||||||
|
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
|
||||||
|
else:
|
||||||
|
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
|
||||||
|
# concat the features back
|
||||||
|
x = torch.cat([x_vis, x_caption, x_history], dim=1)
|
||||||
|
else:
|
||||||
|
x_caption = x[:, -cap_feat_len:]
|
||||||
|
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||||
|
x = torch.cat([x_vis, x_caption], dim=1)
|
||||||
|
|
||||||
|
assert x.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
|
||||||
|
x.size(1), len_init
|
||||||
|
)
|
||||||
|
|
||||||
|
elif expert_flag == 'fusion':
|
||||||
|
x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Pooler(nn.Module):
|
||||||
|
def __init__(self, hidden_size):
|
||||||
|
super(Pooler, self).__init__()
|
||||||
|
|
||||||
|
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
pooled_states = hidden_states[:, 0]
|
||||||
|
pooled_output = self.dense(pooled_states)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
234
models/backbones/moes_huggingface.py
Normal file
234
models/backbones/moes_huggingface.py
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
import warnings
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig
|
||||||
|
|
||||||
|
class MoELayer(nn.Module):
|
||||||
|
def __init__(self, config, expert_type):
|
||||||
|
super(MoELayer, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
self.expert_type = expert_type
|
||||||
|
self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
|
||||||
|
|
||||||
|
# Shared across all experts
|
||||||
|
self.attention = BertAttention(self.bert_config)
|
||||||
|
|
||||||
|
# One for each expert
|
||||||
|
if expert_type == 'modalities':
|
||||||
|
# Spatial expert
|
||||||
|
self.intermediate_spatial = BertIntermediate(self.bert_config)
|
||||||
|
self.output_spatial = BertOutput(self.bert_config)
|
||||||
|
|
||||||
|
# Temporal expert
|
||||||
|
self.intermediate_temporal = BertIntermediate(self.bert_config)
|
||||||
|
self.output_temporal = BertOutput(self.bert_config)
|
||||||
|
|
||||||
|
# Vis Expert
|
||||||
|
self.intermediate_vis = BertIntermediate(self.bert_config)
|
||||||
|
self.output_vis = BertOutput(self.bert_config)
|
||||||
|
|
||||||
|
# Caption Expert
|
||||||
|
self.intermediate_caption = BertIntermediate(self.bert_config)
|
||||||
|
self.output_caption = BertOutput(self.bert_config)
|
||||||
|
|
||||||
|
if config.stage != 'stage_1':
|
||||||
|
# History Expert
|
||||||
|
self.intermediate_history = BertIntermediate(self.bert_config)
|
||||||
|
self.output_history = BertOutput(self.bert_config)
|
||||||
|
|
||||||
|
# Fusion expert
|
||||||
|
elif expert_type == 'fusion':
|
||||||
|
self.intermediate_fusion = BertIntermediate(self.bert_config)
|
||||||
|
self.output_fusion = BertOutput(self.bert_config)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
for _, m in dict(self.named_modules()).items():
|
||||||
|
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
m.bias.data.zero_()
|
||||||
|
m.weight.data.fill_(1.0)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
def get_extended_attention_mask(
|
||||||
|
self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||||
|
input_shape (`Tuple[int]`):
|
||||||
|
The shape of the input to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
|
||||||
|
"""
|
||||||
|
if dtype is None:
|
||||||
|
dtype = self.dtype
|
||||||
|
|
||||||
|
if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
|
||||||
|
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
|
||||||
|
if device is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||||
|
)
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
if attention_mask.dim() == 3:
|
||||||
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
elif attention_mask.dim() == 2:
|
||||||
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
# if self.config.is_decoder:
|
||||||
|
# extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
|
||||||
|
# input_shape, attention_mask, device
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
|
||||||
|
return extended_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):
|
||||||
|
|
||||||
|
input_shape = hidden_states.size()[:-1]
|
||||||
|
# dtype = mask.dtype
|
||||||
|
# device = mask.device
|
||||||
|
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)
|
||||||
|
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
head_mask=None
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
# outputs = self_attention_outputs[1:]
|
||||||
|
|
||||||
|
len_init = attention_output.size(1)
|
||||||
|
# bs, h_dim = x.size(0), x.size(-1)
|
||||||
|
# device = x.device
|
||||||
|
|
||||||
|
|
||||||
|
if expert_flag == 'modalities':
|
||||||
|
if only_text:
|
||||||
|
intermediate_output = self.intermediate_caption(attention_output)
|
||||||
|
layer_output = self.output_caption(intermediate_output, attention_output)
|
||||||
|
else:
|
||||||
|
# split the input first into different parts/modalities
|
||||||
|
unchanged = attention_output[:, :special_toks_indices['<vis>'], :]
|
||||||
|
end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
|
||||||
|
attention_spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]
|
||||||
|
|
||||||
|
end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
|
||||||
|
attention_caption = attention_output[:, special_toks_indices['<caption>']: end_idx_caption, :]
|
||||||
|
|
||||||
|
attention_temporal, attention_history = None, None
|
||||||
|
|
||||||
|
if '<temporal>' in special_toks_indices:
|
||||||
|
end_idx_temporal = special_toks_indices['<caption>']
|
||||||
|
attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]
|
||||||
|
|
||||||
|
if '<history>' in special_toks_indices:
|
||||||
|
end_idx_history = special_toks_indices['</s>'] + 1
|
||||||
|
attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]
|
||||||
|
|
||||||
|
# Expert activation
|
||||||
|
# 1- Spatial
|
||||||
|
intermediate_spatial = self.intermediate_spatial(attention_spatial)
|
||||||
|
output_sapatial = self.output_spatial(intermediate_spatial, attention_spatial)
|
||||||
|
|
||||||
|
output_vis = output_sapatial
|
||||||
|
|
||||||
|
# 2- Temporal
|
||||||
|
if attention_temporal is not None:
|
||||||
|
intermediate_temporal = self.intermediate_temporal(attention_temporal)
|
||||||
|
output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)
|
||||||
|
|
||||||
|
attention_vis = torch.concat([output_sapatial, output_temporal], dim=1)
|
||||||
|
intermediate_vis = self.intermediate_vis(attention_vis)
|
||||||
|
output_vis = self.output_vis(intermediate_vis, attention_vis)
|
||||||
|
|
||||||
|
# 3- Caption
|
||||||
|
intermediate_caption = self.intermediate_caption(attention_caption)
|
||||||
|
output_caption = self.output_caption(intermediate_caption, attention_caption)
|
||||||
|
|
||||||
|
# 4- History
|
||||||
|
if attention_history is not None:
|
||||||
|
intermediate_history = self.intermediate_history(attention_history)
|
||||||
|
output_history = self.output_history(intermediate_history, attention_history)
|
||||||
|
|
||||||
|
output_list = [unchanged, output_vis, output_caption]
|
||||||
|
|
||||||
|
if attention_history is not None:
|
||||||
|
output_list.append(output_history)
|
||||||
|
|
||||||
|
# Concat the features back
|
||||||
|
layer_output = torch.concat(output_list, dim=1)
|
||||||
|
assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
|
||||||
|
layer_output.size(1), len_init
|
||||||
|
)
|
||||||
|
|
||||||
|
elif expert_flag == 'fusion':
|
||||||
|
intermediate_output = self.intermediate_fusion(attention_output)
|
||||||
|
layer_output = self.output_fusion(intermediate_output, attention_output)
|
||||||
|
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class MoEPooler(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MoEPooler, self).__init__()
|
||||||
|
|
||||||
|
self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
|
||||||
|
hidden_size = self.bert_config.hidden_size
|
||||||
|
|
||||||
|
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
for _, m in dict(self.named_modules()).items():
|
||||||
|
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
m.bias.data.zero_()
|
||||||
|
m.weight.data.fill_(1.0)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, hidden_states, idx):
|
||||||
|
pooled_states = hidden_states[:, idx]
|
||||||
|
pooled_output = self.dense(pooled_states)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
247
models/backbones/moes_original.py
Normal file
247
models/backbones/moes_original.py
Normal file
|
@ -0,0 +1,247 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
drop=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=None,
|
||||||
|
attn_drop=0.0,
|
||||||
|
proj_drop=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = (
|
||||||
|
self.qkv(x)
|
||||||
|
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||||
|
.permute(2, 0, 3, 1, 4)
|
||||||
|
)
|
||||||
|
q, k, v = (
|
||||||
|
qkv[0],
|
||||||
|
qkv[1],
|
||||||
|
qkv[2],
|
||||||
|
) # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
if mask is not None:
|
||||||
|
if mask.dim() != x.dim():
|
||||||
|
expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
|
||||||
|
else:
|
||||||
|
expanded_mask = mask
|
||||||
|
expanded_mask = expanded_mask.bool()
|
||||||
|
attn = attn.masked_fill(~expanded_mask, float("-inf"))
|
||||||
|
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x, attn
|
||||||
|
|
||||||
|
|
||||||
|
class MoELayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.0,
|
||||||
|
attn_drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
act_layer=nn.SiLU,
|
||||||
|
norm_layer=LlamaRMSNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_att = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
proj_drop=drop,
|
||||||
|
)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
# EXPERT CONSTRUCTION
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
|
||||||
|
|
||||||
|
# Spatial expert
|
||||||
|
self.norm_spatial = norm_layer(dim)
|
||||||
|
self.mlp_spatial = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Temporal expert
|
||||||
|
self.norm_temp = norm_layer(dim)
|
||||||
|
self.mlp_temp = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vis expert
|
||||||
|
self.norm_vis = norm_layer(dim)
|
||||||
|
self.mlp_vis = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# caption expert
|
||||||
|
self.norm_cap = norm_layer(dim)
|
||||||
|
self.mlp_cap = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# history expert
|
||||||
|
self.norm_hist = norm_layer(dim)
|
||||||
|
self.mlp_hist = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fusion expert
|
||||||
|
self.norm_fusion = norm_layer(dim)
|
||||||
|
self.mlp_fusion = Mlp(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# expert_flag:{Only Text : 00 , Only Image : 01, Fusion : 10, Text & Image : 11} (BINARY)
|
||||||
|
|
||||||
|
# expert_flag:
|
||||||
|
# 0:
|
||||||
|
|
||||||
|
def forward(self, x, special_toks_indices, expert_flag, mask=None):
|
||||||
|
x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
|
||||||
|
x = x + self.drop_path(x_shortcut)
|
||||||
|
bs, h_dim = x.size(0), x.size(-1)
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if expert_flag == 'modalities':
|
||||||
|
end_index = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
|
||||||
|
spatial_feats = x[:, special_toks_indices['<spatial>']: end_index, :]
|
||||||
|
spatial_feats = spatial_feats + self.drop_path(self.mlp_spatial(self.norm_spatial(spatial_feats)))
|
||||||
|
spatial_index = torch.arange(special_toks_indices['<spatial>'], end_index, device=device)
|
||||||
|
spatial_index = spatial_index.unsqueeze(0).unsqueeze(-1)
|
||||||
|
spatial_index = spatial_index.repeat(bs, 1, h_dim)
|
||||||
|
x = x.scatter(1, spatial_index, spatial_feats)
|
||||||
|
# x[:, special_toks_indices['<spatial>']: special_toks_indices['<temporal>'], :] = spatial_feats
|
||||||
|
|
||||||
|
end_index = special_toks_indices.get('<history>', special_toks_indices['</s>'])
|
||||||
|
caption_feats = x[:, special_toks_indices['<caption>']: end_index, :]
|
||||||
|
caption_feats = caption_feats + self.drop_path(self.mlp_cap(self.norm_cap(caption_feats)))
|
||||||
|
caption_index = torch.arange(special_toks_indices['<caption>'], end_index, device=device)
|
||||||
|
caption_index = caption_index.unsqueeze(0).unsqueeze(-1)
|
||||||
|
caption_index = caption_index.repeat(bs, 1, h_dim)
|
||||||
|
x = x.scatter(1, caption_index, caption_feats)
|
||||||
|
|
||||||
|
# x[:, special_toks_indices['<caption>']: special_toks_indices['</s>'], :] = caption_feats
|
||||||
|
|
||||||
|
if '<temporal>' in special_toks_indices:
|
||||||
|
temporal_feats = x[:, special_toks_indices['<temporal>']: special_toks_indices['<caption>'], :]
|
||||||
|
temporal_feats = temporal_feats + self.drop_path(self.mlp_temp(self.norm_temp(temporal_feats)))
|
||||||
|
temporal_index = torch.arange(special_toks_indices['<temporal>'], special_toks_indices['<caption>'], device=device)
|
||||||
|
temporal_index = temporal_index.unsqueeze(0).unsqueeze(-1)
|
||||||
|
temporal_index = temporal_index.repeat(bs, 1, h_dim)
|
||||||
|
x = x.scatter(1, temporal_index, temporal_feats)
|
||||||
|
|
||||||
|
# x[:, special_toks_indices['<temporal>']: special_toks_indices['<caption>'], :] = temporal_feats
|
||||||
|
|
||||||
|
vis_feats = x[:, special_toks_indices['<vis>']: special_toks_indices['<caption>'], :]
|
||||||
|
vis_feats = vis_feats + self.drop_path(self.mlp_vis(self.norm_vis(vis_feats)))
|
||||||
|
vis_index = torch.arange(special_toks_indices['<vis>'], special_toks_indices['<caption>'], device=device)
|
||||||
|
vis_index = vis_index.unsqueeze(0).unsqueeze(-1)
|
||||||
|
vis_index = vis_index.repeat(bs, 1, h_dim)
|
||||||
|
x = x.scatter(1, vis_index, vis_feats)
|
||||||
|
|
||||||
|
# x[:, special_toks_indices['<vis>']: special_toks_indices['<caption>'], :] = vis_feats
|
||||||
|
|
||||||
|
if '<history>' in special_toks_indices:
|
||||||
|
history_feats = x[:, special_toks_indices['<history>']: special_toks_indices['</s>'], :]
|
||||||
|
history_feats = history_feats + self.drop_path(self.mlp_hist(self.norm_hist(history_feats)))
|
||||||
|
history_index = torch.arange(special_toks_indices['<history>'], special_toks_indices['</s>'], device=device)
|
||||||
|
history_index = history_index.unsqueeze(0).unsqueeze(-1)
|
||||||
|
history_index = history_index.repeat(bs, 1, h_dim)
|
||||||
|
x = x.scatter(1, history_index, history_feats)
|
||||||
|
|
||||||
|
elif expert_flag == 'fusion':
|
||||||
|
x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))
|
||||||
|
|
||||||
|
return x, attn
|
||||||
|
|
||||||
|
# if expert_flag == 2:
|
||||||
|
# x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
# elif expert_flag == 0:
|
||||||
|
# x = (x[:, -it_split:])
|
||||||
|
# x = x + self.drop_path(self.sentence_mlp(self.sentence_norm(x)))
|
||||||
|
# elif expert_flag == 1:
|
||||||
|
# x = (x[:, :-it_split ])
|
||||||
|
# x = x + self.drop_path(self.image_mlp(self.image_norm(x)))
|
||||||
|
# elif expert_flag == 3:
|
||||||
|
# text, image = (x[:, :it_split], x[:, it_split:],)
|
||||||
|
# text = text + self.drop_path(self.sentence_mlp(self.sentence_norm(text)))
|
||||||
|
# image = image + self.drop_path(self.image_mlp(self.image_norm(image)))
|
||||||
|
# x = torch.cat([text, image], dim=1)
|
||||||
|
# elif expert_flag == 4:
|
||||||
|
# x = x + self.drop_path(self.generation_mlp(self.generation_norm(x)))
|
||||||
|
# return x, attn
|
0
models/common/__init__.py
Executable file
0
models/common/__init__.py
Executable file
474
models/common/config.py
Executable file
474
models/common/config.py
Executable file
|
@ -0,0 +1,474 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from minigpt4.common.registry import registry
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
def __init__(self, args):
|
||||||
|
self.config = {}
|
||||||
|
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
# Register the config and configuration for setup
|
||||||
|
registry.register("configuration", self)
|
||||||
|
|
||||||
|
user_config = self._build_opt_list(self.args.options)
|
||||||
|
|
||||||
|
config = OmegaConf.load(self.args.cfg_path)
|
||||||
|
|
||||||
|
runner_config = self.build_runner_config(config)
|
||||||
|
model_config = self.build_model_config(config, **user_config)
|
||||||
|
dataset_config = self.build_dataset_config(config)
|
||||||
|
|
||||||
|
# Validate the user-provided runner configuration
|
||||||
|
# model and dataset configuration are supposed to be validated by the respective classes
|
||||||
|
# [TODO] validate the model/dataset configuration
|
||||||
|
# self._validate_runner_config(runner_config)
|
||||||
|
|
||||||
|
# Override the default configuration with user options.
|
||||||
|
self.config = OmegaConf.merge(
|
||||||
|
runner_config, model_config, dataset_config, user_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_runner_config(self, runner_config):
|
||||||
|
"""
|
||||||
|
This method validates the configuration, such that
|
||||||
|
1) all the user specified options are valid;
|
||||||
|
2) no type mismatches between the user specified options and the config.
|
||||||
|
"""
|
||||||
|
runner_config_validator = create_runner_config_validator()
|
||||||
|
runner_config_validator.validate(runner_config)
|
||||||
|
|
||||||
|
def _build_opt_list(self, opts):
|
||||||
|
opts_dot_list = self._convert_to_dot_list(opts)
|
||||||
|
return OmegaConf.from_dotlist(opts_dot_list)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_model_config(config, **kwargs):
|
||||||
|
model = config.get("model", None)
|
||||||
|
assert model is not None, "Missing model configuration file."
|
||||||
|
|
||||||
|
model_cls = registry.get_model_class(model.arch)
|
||||||
|
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
||||||
|
|
||||||
|
model_type = kwargs.get("model.model_type", None)
|
||||||
|
if not model_type:
|
||||||
|
model_type = model.get("model_type", None)
|
||||||
|
# else use the model type selected by user.
|
||||||
|
|
||||||
|
assert model_type is not None, "Missing model_type."
|
||||||
|
|
||||||
|
print("--------------")
|
||||||
|
print("model arch",model.arch)
|
||||||
|
print("model cls",model_cls)
|
||||||
|
|
||||||
|
model_config_path = model_cls.default_config_path(model_type=model_type)
|
||||||
|
|
||||||
|
model_config = OmegaConf.create()
|
||||||
|
# hierarchy override, customized config > default config
|
||||||
|
model_config = OmegaConf.merge(
|
||||||
|
model_config,
|
||||||
|
OmegaConf.load(model_config_path),
|
||||||
|
{"model": config["model"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_runner_config(config):
|
||||||
|
return {"run": config.run}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_dataset_config(config):
|
||||||
|
datasets = config.get("datasets", None)
|
||||||
|
if datasets is None:
|
||||||
|
raise KeyError(
|
||||||
|
"Expecting 'datasets' as the root key for dataset configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_config = OmegaConf.create()
|
||||||
|
|
||||||
|
for dataset_name in datasets:
|
||||||
|
|
||||||
|
print("dataset name", dataset_name)
|
||||||
|
builder_cls = registry.get_builder_class(dataset_name)
|
||||||
|
|
||||||
|
dataset_config_type = datasets[dataset_name].get("type", "default")
|
||||||
|
dataset_config_path = builder_cls.default_config_path(
|
||||||
|
type=dataset_config_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# hierarchy override, customized config > default config
|
||||||
|
dataset_config = OmegaConf.merge(
|
||||||
|
dataset_config,
|
||||||
|
OmegaConf.load(dataset_config_path),
|
||||||
|
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset_config
|
||||||
|
|
||||||
|
def _convert_to_dot_list(self, opts):
|
||||||
|
if opts is None:
|
||||||
|
opts = []
|
||||||
|
|
||||||
|
if len(opts) == 0:
|
||||||
|
return opts
|
||||||
|
|
||||||
|
has_equal = opts[0].find("=") != -1
|
||||||
|
|
||||||
|
if has_equal:
|
||||||
|
return opts
|
||||||
|
|
||||||
|
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_cfg(self):
|
||||||
|
return self.config.run
|
||||||
|
|
||||||
|
@property
|
||||||
|
def datasets_cfg(self):
|
||||||
|
return self.config.datasets
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_cfg(self):
|
||||||
|
return self.config.model
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
logging.info("\n===== Running Parameters =====")
|
||||||
|
logging.info(self._convert_node_to_json(self.config.run))
|
||||||
|
|
||||||
|
logging.info("\n====== Dataset Attributes ======")
|
||||||
|
datasets = self.config.datasets
|
||||||
|
|
||||||
|
for dataset in datasets:
|
||||||
|
if dataset in self.config.datasets:
|
||||||
|
logging.info(f"\n======== {dataset} =======")
|
||||||
|
dataset_config = self.config.datasets[dataset]
|
||||||
|
logging.info(self._convert_node_to_json(dataset_config))
|
||||||
|
else:
|
||||||
|
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
||||||
|
|
||||||
|
logging.info(f"\n====== Model Attributes ======")
|
||||||
|
logging.info(self._convert_node_to_json(self.config.model))
|
||||||
|
|
||||||
|
def _convert_node_to_json(self, node):
|
||||||
|
container = OmegaConf.to_container(node, resolve=True)
|
||||||
|
return json.dumps(container, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return OmegaConf.to_container(self.config)
|
||||||
|
|
||||||
|
|
||||||
|
def node_to_dict(node):
|
||||||
|
return OmegaConf.to_container(node)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigValidator:
|
||||||
|
"""
|
||||||
|
This is a preliminary implementation to centralize and validate the configuration.
|
||||||
|
May be altered in the future.
|
||||||
|
|
||||||
|
A helper class to validate configurations from yaml file.
|
||||||
|
|
||||||
|
This serves the following purposes:
|
||||||
|
1. Ensure all the options in the yaml are defined, raise error if not.
|
||||||
|
2. when type mismatches are found, the validator will raise an error.
|
||||||
|
3. a central place to store and display helpful messages for supported configurations.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
class _Argument:
|
||||||
|
def __init__(self, name, choices=None, type=None, help=None):
|
||||||
|
self.name = name
|
||||||
|
self.val = None
|
||||||
|
self.choices = choices
|
||||||
|
self.type = type
|
||||||
|
self.help = help
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = f"{self.name}={self.val}"
|
||||||
|
if self.type is not None:
|
||||||
|
s += f", ({self.type})"
|
||||||
|
if self.choices is not None:
|
||||||
|
s += f", choices: {self.choices}"
|
||||||
|
if self.help is not None:
|
||||||
|
s += f", ({self.help})"
|
||||||
|
return s
|
||||||
|
|
||||||
|
def __init__(self, description):
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
self.arguments = dict()
|
||||||
|
|
||||||
|
self.parsed_args = None
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
assert self.parsed_args is not None, "No arguments parsed yet."
|
||||||
|
|
||||||
|
return self.parsed_args[key]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.format_help()
|
||||||
|
|
||||||
|
def add_argument(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Assume the first argument is the name of the argument.
|
||||||
|
"""
|
||||||
|
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
||||||
|
|
||||||
|
def validate(self, config=None):
|
||||||
|
"""
|
||||||
|
Convert yaml config (dict-like) to list, required by argparse.
|
||||||
|
"""
|
||||||
|
for k, v in config.items():
|
||||||
|
assert (
|
||||||
|
k in self.arguments
|
||||||
|
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
||||||
|
|
||||||
|
if self.arguments[k].type is not None:
|
||||||
|
try:
|
||||||
|
self.arguments[k].val = self.arguments[k].type(v)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
||||||
|
|
||||||
|
if self.arguments[k].choices is not None:
|
||||||
|
assert (
|
||||||
|
v in self.arguments[k].choices
|
||||||
|
), f"""{k} must be one of {self.arguments[k].choices}."""
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def format_arguments(self):
|
||||||
|
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
||||||
|
|
||||||
|
def format_help(self):
|
||||||
|
# description + key-value pair string for each argument
|
||||||
|
help_msg = str(self.description)
|
||||||
|
return help_msg + ", available arguments: " + self.format_arguments()
|
||||||
|
|
||||||
|
def print_help(self):
|
||||||
|
# display help message
|
||||||
|
print(self.format_help())
|
||||||
|
|
||||||
|
|
||||||
|
def create_runner_config_validator():
|
||||||
|
validator = ConfigValidator(description="Runner configurations")
|
||||||
|
|
||||||
|
validator.add_argument(
|
||||||
|
"runner",
|
||||||
|
type=str,
|
||||||
|
choices=["runner_base", "runner_iter"],
|
||||||
|
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
||||||
|
runner runs based on iters. Default: runner_base""",
|
||||||
|
)
|
||||||
|
# add argumetns for training dataset ratios
|
||||||
|
validator.add_argument(
|
||||||
|
"train_dataset_ratios",
|
||||||
|
type=Dict[str, float],
|
||||||
|
help="""Ratios of training dataset. This is used in iteration-based runner.
|
||||||
|
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
||||||
|
Default: None""",
|
||||||
|
)
|
||||||
|
validator.add_argument(
|
||||||
|
"max_iters",
|
||||||
|
type=float,
|
||||||
|
help="Maximum number of iterations to run.",
|
||||||
|
)
|
||||||
|
validator.add_argument(
|
||||||
|
"max_epoch",
|
||||||
|
type=int,
|
||||||
|
help="Maximum number of epochs to run.",
|
||||||
|
)
|
||||||
|
# add arguments for iters_per_inner_epoch
|
||||||
|
validator.add_argument(
|
||||||
|
"iters_per_inner_epoch",
|
||||||
|
type=float,
|
||||||
|
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
||||||
|
)
|
||||||
|
lr_scheds_choices = registry.list_lr_schedulers()
|
||||||
|
validator.add_argument(
|
||||||
|
"lr_sched",
|
||||||
|
type=str,
|
||||||
|
choices=lr_scheds_choices,
|
||||||
|
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
||||||
|
)
|
||||||
|
task_choices = registry.list_tasks()
|
||||||
|
validator.add_argument(
|
||||||
|
"task",
|
||||||
|
type=str,
|
||||||
|
choices=task_choices,
|
||||||
|
help="Task to use, from {}".format(task_choices),
|
||||||
|
)
|
||||||
|
# add arguments for init_lr
|
||||||
|
validator.add_argument(
|
||||||
|
"init_lr",
|
||||||
|
type=float,
|
||||||
|
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
||||||
|
)
|
||||||
|
# add arguments for min_lr
|
||||||
|
validator.add_argument(
|
||||||
|
"min_lr",
|
||||||
|
type=float,
|
||||||
|
help="Minimum learning rate (after decay).",
|
||||||
|
)
|
||||||
|
# add arguments for warmup_lr
|
||||||
|
validator.add_argument(
|
||||||
|
"warmup_lr",
|
||||||
|
type=float,
|
||||||
|
help="Starting learning rate for warmup.",
|
||||||
|
)
|
||||||
|
# add arguments for learning rate decay rate
|
||||||
|
validator.add_argument(
|
||||||
|
"lr_decay_rate",
|
||||||
|
type=float,
|
||||||
|
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
||||||
|
)
|
||||||
|
# add arguments for weight decay
|
||||||
|
validator.add_argument(
|
||||||
|
"weight_decay",
|
||||||
|
type=float,
|
||||||
|
help="Weight decay rate.",
|
||||||
|
)
|
||||||
|
# add arguments for training batch size
|
||||||
|
validator.add_argument(
|
||||||
|
"batch_size_train",
|
||||||
|
type=int,
|
||||||
|
help="Training batch size.",
|
||||||
|
)
|
||||||
|
# add arguments for evaluation batch size
|
||||||
|
validator.add_argument(
|
||||||
|
"batch_size_eval",
|
||||||
|
type=int,
|
||||||
|
help="Evaluation batch size, including validation and testing.",
|
||||||
|
)
|
||||||
|
# add arguments for number of workers for data loading
|
||||||
|
validator.add_argument(
|
||||||
|
"num_workers",
|
||||||
|
help="Number of workers for data loading.",
|
||||||
|
)
|
||||||
|
# add arguments for warm up steps
|
||||||
|
validator.add_argument(
|
||||||
|
"warmup_steps",
|
||||||
|
type=int,
|
||||||
|
help="Number of warmup steps. Required if a warmup schedule is used.",
|
||||||
|
)
|
||||||
|
# add arguments for random seed
|
||||||
|
validator.add_argument(
|
||||||
|
"seed",
|
||||||
|
type=int,
|
||||||
|
help="Random seed.",
|
||||||
|
)
|
||||||
|
# add arguments for output directory
|
||||||
|
validator.add_argument(
|
||||||
|
"output_dir",
|
||||||
|
type=str,
|
||||||
|
help="Output directory to save checkpoints and logs.",
|
||||||
|
)
|
||||||
|
# add arguments for whether only use evaluation
|
||||||
|
validator.add_argument(
|
||||||
|
"evaluate",
|
||||||
|
help="Whether to only evaluate the model. If true, training will not be performed.",
|
||||||
|
)
|
||||||
|
# add arguments for splits used for training, e.g. ["train", "val"]
|
||||||
|
validator.add_argument(
|
||||||
|
"train_splits",
|
||||||
|
type=list,
|
||||||
|
help="Splits to use for training.",
|
||||||
|
)
|
||||||
|
# add arguments for splits used for validation, e.g. ["val"]
|
||||||
|
validator.add_argument(
|
||||||
|
"valid_splits",
|
||||||
|
type=list,
|
||||||
|
help="Splits to use for validation. If not provided, will skip the validation.",
|
||||||
|
)
|
||||||
|
# add arguments for splits used for testing, e.g. ["test"]
|
||||||
|
validator.add_argument(
|
||||||
|
"test_splits",
|
||||||
|
type=list,
|
||||||
|
help="Splits to use for testing. If not provided, will skip the testing.",
|
||||||
|
)
|
||||||
|
# add arguments for accumulating gradient for iterations
|
||||||
|
validator.add_argument(
|
||||||
|
"accum_grad_iters",
|
||||||
|
type=int,
|
||||||
|
help="Number of iterations to accumulate gradient for.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ====== distributed training ======
|
||||||
|
validator.add_argument(
|
||||||
|
"device",
|
||||||
|
type=str,
|
||||||
|
choices=["cpu", "cuda"],
|
||||||
|
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
||||||
|
)
|
||||||
|
validator.add_argument(
|
||||||
|
"world_size",
|
||||||
|
type=int,
|
||||||
|
help="Number of processes participating in the job.",
|
||||||
|
)
|
||||||
|
validator.add_argument("dist_url", type=str)
|
||||||
|
validator.add_argument("distributed", type=bool)
|
||||||
|
# add arguments to opt using distributed sampler during evaluation or not
|
||||||
|
validator.add_argument(
|
||||||
|
"use_dist_eval_sampler",
|
||||||
|
type=bool,
|
||||||
|
help="Whether to use distributed sampler during evaluation or not.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ====== task specific ======
|
||||||
|
# generation task specific arguments
|
||||||
|
# add arguments for maximal length of text output
|
||||||
|
validator.add_argument(
|
||||||
|
"max_len",
|
||||||
|
type=int,
|
||||||
|
help="Maximal length of text output.",
|
||||||
|
)
|
||||||
|
# add arguments for minimal length of text output
|
||||||
|
validator.add_argument(
|
||||||
|
"min_len",
|
||||||
|
type=int,
|
||||||
|
help="Minimal length of text output.",
|
||||||
|
)
|
||||||
|
# add arguments number of beams
|
||||||
|
validator.add_argument(
|
||||||
|
"num_beams",
|
||||||
|
type=int,
|
||||||
|
help="Number of beams used for beam search.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# vqa task specific arguments
|
||||||
|
# add arguments for number of answer candidates
|
||||||
|
validator.add_argument(
|
||||||
|
"num_ans_candidates",
|
||||||
|
type=int,
|
||||||
|
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
||||||
|
)
|
||||||
|
# add arguments for inference method
|
||||||
|
validator.add_argument(
|
||||||
|
"inference_method",
|
||||||
|
type=str,
|
||||||
|
choices=["genearte", "rank"],
|
||||||
|
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ====== model specific ======
|
||||||
|
validator.add_argument(
|
||||||
|
"k_test",
|
||||||
|
type=int,
|
||||||
|
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return validator
|
203
models/common/dist_utils.py
Executable file
203
models/common/dist_utils.py
Executable file
|
@ -0,0 +1,203 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import timm.models.hub as timm_hub
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if args.distributed is False:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.rank = 0
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
|
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||||
|
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
elif "SLURM_PROCID" in os.environ:
|
||||||
|
args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
args.rank = 0
|
||||||
|
return
|
||||||
|
|
    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # if use distributed training
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
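
A minimal usage sketch of the two gather helpers above, showing when each is appropriate. It assumes the module path models.common.dist_utils, an already-initialized process group, and an available GPU; it is illustrative only.

import torch
from models.common.dist_utils import all_gather_with_grad, concat_all_gather

feats = torch.randn(8, 256, device="cuda", requires_grad=True)

# Gradient-preserving gather: use when the gathered features feed a loss
# (e.g. a contrastive objective) that must backpropagate on every rank.
all_feats = all_gather_with_grad(feats)

# Gradient-free gather: cheaper, suitable for negatives or statistics only.
all_feats_detached = concat_all_gather(feats)
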
224
models/common/eval_utils.py
Normal file
@ -0,0 +1,224 @@
import argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
import sys
sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img')
from minigpt4.common.registry import registry
from minigpt4.common.config import Config

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
# from minigpt4.runners import *
from minigpt4.tasks import *
from pycocoevalcap.cider.cider import Cider
import os
import openai
from tqdm import tqdm
import json
import ast
import time


def eval_parser():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", help="path to configuration file.", default="test_configs/llama2_test_config.yaml")
    parser.add_argument("--ckpt", type=str, default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
    parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser


def prepare_texts(texts, conv_temp, template='<Img><ImageHere></Img>', lengths=None):
    convs = [conv_temp.copy() for _ in range(len(texts))]
    if lengths is None:
        [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)]
    else:
        templates = [template * length for length in lengths]
        [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)]
    [conv.append_message(conv.roles[1], None) for conv in convs]
    texts = [conv.get_prompt() for conv in convs]
    return texts


def init_model(args):
    print('Initialization Model')
    cfg = Config(args)
    cfg.model_cfg.ckpt = args.ckpt
    cfg.model_cfg.lora_r = args.lora_r
    cfg.model_cfg.lora_alpha = args.lora_alpha

    model_config = cfg.model_cfg
    model_config.low_resource = True
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')

    # import pudb; pudb.set_trace()
    key = list(cfg.datasets_cfg.keys())[0]
    vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
    print(vis_processor_cfg)
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    print('Initialization Finished')
    return model, vis_processor


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def eval_bleu(results):
    bleus1, bleus2, bleus3, bleus4 = [], [], [], []
    for result in tqdm(results, desc="bleu_eval"):
        gt = result['gt']
        pred = result['pred']
        bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1, 0, 0, 0)))
        bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5, 0.5, 0, 0)))
        bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33, 0.33, 0.33, 0)))
        bleus4.append(sentence_bleu([gt.split()], pred.split()))
    # print(np.mean(bleus1), np.mean(bleus2), np.mean(bleus3), np.mean(bleus4), flush=True)
    return {'bleu1': np.mean(bleus1), 'bleu2': np.mean(bleus2), 'bleu3': np.mean(bleus3), 'bleu4': np.mean(bleus4)}


# Create a Cider object
cider_scorer = Cider()

def eval_cider(pred_result, gt_result):
    # Compute CIDEr scores
    mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result)
    cider_scores_dict = {}
    for score, pred_vid_id, gt_vid_id in tqdm(zip(cider_scores.tolist(), pred_result, gt_result), desc="cider_eval"):
        assert pred_vid_id == gt_vid_id
        cider_scores_dict[pred_vid_id] = score
    return {'mean_cider_scores': mean_cider_scores, 'cider_scores': cider_scores_dict}


openai.api_key_path = "/home/ataallka/chatgpt_api.txt"


def chat_gpt_eval(results, output_path):
    trial = 0
    gpt_results = []
    avg_chatgpt_score = 0
    existed_files = {}
    # read previous results from output path
    for file in os.listdir(output_path):
        if file.endswith(".json"):
            with open(f'{output_path}/{file}') as json_file:
                data = json.load(json_file)
                gpt_results.append(data[0])
                avg_chatgpt_score += float(data[0]['chatgpt_score'])
                existed_files[data[0]['video_name']] = True
    length_output_path = len(os.listdir(output_path))
    while len(results) != length_output_path:
        for res in tqdm(results, desc="chatgpt_eval"):
            if existed_files.get(res['video_name'], False):
                continue
            video_name = res['video_name']
            sentence_1 = res['A']
            sentence_2 = res['pred']
            try:
                # prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
                prompt = f"given these 2 sentences the first one is the ground truth descrption of a video and the second sentence is the generated text from a video summarization model,give it a score from 0 to 5 to evaluate the model summarization performance.the output should be only the score number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
                response = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=[
                        {
                            "role": "user",
                            "content": prompt
                        }],
                )
                res['chatgpt_score'] = response.choices[0].message['content']
                out = {'video_name': video_name, 'chatgpt_score': response.choices[0].message['content']}
                gpt_results.append(out)
                # save each video result in a json file
                with open(f'{output_path}/{video_name}.json', 'w') as f:
                    json.dump([out], f)
                avg_chatgpt_score += float(response.choices[0].message['content'])
            except Exception as e:
                print("chat gpt error", e)
        print("Finished chat gpt evaluation in trial", trial)
        trial += 1
        length_output_path = len(os.listdir(output_path))
    return results, avg_chatgpt_score / len(results)


def GPT4_answer(question, answer, pred):
    try:
        # Compute the correctness score
        completion = openai.ChatCompletion.create(
            # model="gpt-3.5-turbo",
            model='gpt-4',
            messages=[
                {
                    "role": "system",
                    "content":
                        "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
                        "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
                        "------"
                        "##INSTRUCTIONS: "
                        "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
                        "- Consider synonyms or paraphrases as valid matches.\n"
                        "- Evaluate the correctness of the prediction compared to the answer."
                },
                {
                    "role": "user",
                    "content":
                        "Please evaluate the following video-based question-answer pair:\n\n"
                        f"Question: {question}\n"
                        f"Correct Answer: {answer}\n"
                        f"Predicted Answer: {pred}\n\n"
                        "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
                        "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
                        "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
                        "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
                }
            ]
        )
        # Convert response to a Python dictionary.
        response_message = completion["choices"][0]["message"]["content"]
        response_dict = ast.literal_eval(response_message)
        return response_dict
    except Exception as e:
        print(f"Error : {e}")
        return None


def GPT4_evaluation(val_result):
    scores = []
    yes_count = 0
    no_count = 0
    for res in val_result:
        gpt_response = GPT4_answer(res['Q'], res['A'], res['pred'])
        if gpt_response is None:
            continue
        try:
            scores.append(float(gpt_response['score']))
            if 'yes' in gpt_response['pred'].lower():
                yes_count += 1
            elif 'no' in gpt_response['pred'].lower():
                no_count += 1
        except:
            continue
    avg_score = sum(scores) / len(scores)
    accuracy = (yes_count / (yes_count + no_count)) * 100
    print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
    return avg_score, accuracy

# with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f:
#     results = json.load(f)
# t1=time.time()
# avg_score,accuracy=GPT4_evaluation(results)
# print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
# print(f"Time taken: {time.time()-t1}")
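
A short sketch of how the caption metrics above are typically driven. The result entries are made-up examples; only the 'gt'/'pred' keys are assumed, which is what eval_bleu reads.

from models.common.eval_utils import eval_bleu, computeIoU

results = [
    {'gt': 'a man is cooking in the kitchen', 'pred': 'a man cooks food in a kitchen'},
    {'gt': 'two dogs play in the snow', 'pred': 'dogs are playing outside'},
]
print(eval_bleu(results))  # {'bleu1': ..., 'bleu2': ..., 'bleu3': ..., 'bleu4': ...}

# Bounding boxes are (x1, y1, x2, y2); identical boxes give IoU == 1.0.
print(computeIoU((0, 0, 10, 10), (5, 5, 15, 15)))
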
24
models/common/gradcam.py
Executable file
@ -0,0 +1,24 @@
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
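
getAttMap blends a low-resolution attention map over an RGB image in [0, 1]; a minimal sketch with random stand-in inputs (illustrative only):

import numpy as np
from models.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)   # float RGB image in [0, 1]
att = np.random.rand(16, 16)        # e.g. a 16x16 patch-level attention map
overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)                # (224, 224, 3) overlay ready for plt.imshow
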
195
models/common/logger.py
Executable file
@ -0,0 +1,195 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from minigpt4.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
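
A minimal sketch of the intended training-loop usage of MetricLogger; the loop body and loss values are dummies, only the API calls are taken from the file above.

import torch
from models.common.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

# log_every is a generator: it yields each item and prints stats every print_freq steps.
for step in metric_logger.log_every(range(100), print_freq=10, header="Train:"):
    loss = torch.rand(1)                          # stand-in for a real training loss
    metric_logger.update(loss=loss.item(), lr=1e-4)

metric_logger.synchronize_between_processes()     # no-op when not distributed
print("Averaged stats:", metric_logger.global_avg())
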
119
models/common/optims.py
Executable file
@ -0,0 +1,119 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from minigpt4.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=total_cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
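
A small sketch of how the warmup-plus-cosine scheduler above is driven once per iteration; the model, optimizer, and epoch counts are placeholders chosen for illustration.

import torch
from models.common.optims import LinearWarmupCosineLRScheduler

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = LinearWarmupCosineLRScheduler(
    optimizer, max_epoch=5, iters_per_epoch=100,
    min_lr=1e-5, init_lr=1e-4, warmup_steps=50, warmup_start_lr=1e-6,
)

for epoch in range(5):
    for it in range(100):
        scheduler.step(cur_epoch=epoch, cur_step=it)   # set the LR for this iteration
        loss = model(torch.randn(4, 10)).sum()          # dummy forward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
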
330
models/common/registry.py
Executable file
@ -0,0 +1,330 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage:

            from minigpt4.common.registry import registry
            from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(task_cls):
            from minigpt4.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(model_cls):
            # from minigpt4.models import BaseModel

            # assert issubclass(
            #     model_cls, BaseModel
            # ), "All models must inherit BaseModel class"

            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(processor_cls):
            from minigpt4.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from minigpt4.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from minigpt4.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Useful for MMF's
                               internal operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from mmf.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
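
The registry is used as a decorator at class-definition time and queried by name at run time; a compact sketch below, where the scheduler class and its key are invented purely for illustration.

from models.common.registry import registry

@registry.register_lr_scheduler("dummy_constant_lr")
class DummyConstantLRScheduler:
    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer

    def step(self, cur_epoch, cur_step):
        pass  # leave the optimizer's LR unchanged

sched_cls = registry.get_lr_scheduler_class("dummy_constant_lr")
assert sched_cls is DummyConstantLRScheduler
print(registry.list_lr_schedulers())
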
424
models/common/utils.py
Executable file
@ -0,0 +1,424 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from models.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    return datetime.now().strftime("%Y%m%d%H%M")


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        print(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, returns the URL it redirects to or the
    original URL in case of no indirection
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive
    Downloading an URL from google drive requires confirmation when
    the file of the size is too big (google drive notifies that
    anti-viral checks cannot be performed on such files)
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                for chunk in iter(lambda: response.read(chunk_size), ""):
                    if not chunk:
                        break
                    pbar.update(chunk_size)
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.
    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # expand redirect chain if needed
    url = get_redirected_url(url)

    # check if file is located on Google Drive
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # download the file
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached


# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Simply create the symlinks for a given file1 to file2.
    Useful during model checkpointing to symlinks to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing in Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .pkl, .pickle, .npy, .json
    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def is_url(input_url):
    """
    Check if an input string is a url. look for http(s):// and ignoring the case
    """
    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
    return is_url


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, get the size of file in MB
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb
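
A short sketch of the save_file/load_file round trip from the utilities above, assuming the commit's dependencies (iopath, torchvision, pandas, yaml) are installed; the path and record are placeholders.

from models.common.utils import save_file, load_file, get_file_size

record = {"epoch": 1, "loss": 0.42}
save_file(record, "/tmp/metrics.json", append_to_json=True)  # appends one JSON object per line
data = load_file("/tmp/metrics.json")                         # json.load; round-trips for a single record
print(data, get_file_size("/tmp/metrics.json"), "MB")
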
@ -0,0 +1,89 @@
# coding: utf-8

import sys
dataDir = '../../VQA'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' % (dataDir))
from vqa import VQA
from vqaEvaluation.vqaEval import VQAEval
import matplotlib.pyplot as plt
import skimage.io as io
import json
import random
import os

# set up file names and paths
versionType = 'v2_'  # this should be '' when using VQA v2.0 dataset
taskType = 'OpenEnded'  # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType = 'mscoco'  # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType = 'train2014'
annFile = '%s/Annotations/%s%s_%s_annotations.json' % (dataDir, versionType, dataType, dataSubType)
quesFile = '%s/Questions/%s%s_%s_%s_questions.json' % (dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/%s/' % (dataDir, dataType, dataSubType)
resultType = 'fake'
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']

# An example result json file has been provided in './Results' folder.

[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType, \
    resultType, fileType) for fileType in fileTypes]

# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)

# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2

# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
By default it uses all the question ids in annotation file
"""
vqaEval.evaluate()

# print accuracies
print("\n")
print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
print("Per Question Type Accuracy is the following:")
for quesType in vqaEval.accuracy['perQuestionType']:
    print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
print("\n")
print("Per Answer Type Accuracy is the following:")
for ansType in vqaEval.accuracy['perAnswerType']:
    print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
print("\n")
# demo how to use evalQA to retrieve low score result
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId] < 35]  # 35 is per question percentage accuracy
if len(evals) > 0:
    print('ground truth answers')
    randomEval = random.choice(evals)
    randomAnn = vqa.loadQA(randomEval)
    vqa.showQA(randomAnn)

    print('\n')
    print('generated answer (accuracy %.02f)' % (vqaEval.evalQA[randomEval]))
    ann = vqaRes.loadQA(randomEval)[0]
    print("Answer: %s\n" % (ann['answer']))

    imgId = randomAnn[0]['image_id']
    imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
    if os.path.isfile(imgDir + imgFilename):
        I = io.imread(imgDir + imgFilename)
        plt.imshow(I)
        plt.axis('off')
        plt.show()

# plot accuracy for various question types
plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), list(vqaEval.accuracy['perQuestionType'].values()), align='center')
plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), list(vqaEval.accuracy['perQuestionType'].keys()), rotation='0', fontsize=10)
plt.title('Per Question Type Accuracy', fontsize=10)
plt.xlabel('Question Types', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.show()

# save evaluation results to ./Results folder
json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
@ -0,0 +1 @@
|
||||||
|
author='aagrawal'
|
|
@ -0,0 +1,192 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
__author__='aagrawal'
|
||||||
|
|
||||||
|
import re
|
||||||
|
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
||||||
|
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class VQAEval:
|
||||||
|
def __init__(self, vqa, vqaRes, n=2):
|
||||||
|
self.n = n
|
||||||
|
self.accuracy = {}
|
||||||
|
self.evalQA = {}
|
||||||
|
self.evalQuesType = {}
|
||||||
|
self.evalAnsType = {}
|
||||||
|
self.vqa = vqa
|
||||||
|
self.vqaRes = vqaRes
|
||||||
|
self.params = {'question_id': vqa.getQuesIds()}
|
||||||
|
self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
|
||||||
|
"couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
|
||||||
|
"hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
|
||||||
|
"he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
|
||||||
|
"Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
|
||||||
|
"maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
|
||||||
|
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
|
||||||
|
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
|
||||||
|
"she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
|
||||||
|
"somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
|
||||||
|
"somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
|
||||||
|
"someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
|
||||||
|
"something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
|
||||||
|
"there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
|
||||||
|
"they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
|
||||||
|
"wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
|
||||||
|
"whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
|
||||||
|
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
|
||||||
|
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
|
||||||
|
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
|
||||||
|
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
|
||||||
|
"youll": "you'll", "youre": "you're", "youve": "you've"}
|
||||||
|
self.manualMap = { 'none': '0',
|
||||||
|
'zero': '0',
|
||||||
|
'one': '1',
|
||||||
|
'two': '2',
|
||||||
|
'three': '3',
|
||||||
|
'four': '4',
|
||||||
|
'five': '5',
|
||||||
|
'six': '6',
|
||||||
|
'seven': '7',
|
||||||
|
'eight': '8',
|
||||||
|
'nine': '9',
|
||||||
|
'ten': '10'
|
||||||
|
}
|
||||||
|
self.articles = ['a',
|
||||||
|
'an',
|
||||||
|
'the'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
self.periodStrip = re.compile("(?<!\d)(\.)(?!\d)")
|
||||||
|
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
||||||
|
self.punct = [';', r"/", '[', ']', '"', '{', '}',
|
||||||
|
'(', ')', '=', '+', '\\', '_', '-',
|
||||||
|
'>', '<', '@', '`', ',', '?', '!']
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(self, quesIds=None):
|
||||||
|
if quesIds == None:
|
||||||
|
quesIds = [quesId for quesId in self.params['question_id']]
|
||||||
|
gts = {}
|
||||||
|
res = {}
|
||||||
|
for quesId in quesIds:
|
||||||
|
gts[quesId] = self.vqa.qa[quesId]
|
||||||
|
res[quesId] = self.vqaRes.qa[quesId]
|
||||||
|
|
||||||
|
# =================================================
|
||||||
|
# Compute accuracy
|
||||||
|
# =================================================
|
||||||
|
accQA = []
|
||||||
|
accQuesType = {}
|
||||||
|
accAnsType = {}
|
||||||
|
# print "computing accuracy"
|
||||||
|
step = 0
|
||||||
|
for quesId in quesIds:
|
||||||
|
for ansDic in gts[quesId]['answers']:
|
||||||
|
ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
|
||||||
|
ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
|
||||||
|
ansDic['answer'] = ansDic['answer'].strip()
|
||||||
|
resAns = res[quesId]['answer']
|
||||||
|
resAns = resAns.replace('\n', ' ')
|
||||||
|
resAns = resAns.replace('\t', ' ')
|
||||||
|
resAns = resAns.strip()
|
||||||
|
gtAcc = []
|
||||||
|
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
|
||||||
|
|
||||||
|
if len(set(gtAnswers)) > 1:
|
||||||
|
for ansDic in gts[quesId]['answers']:
|
||||||
|
ansDic['answer'] = self.processPunctuation(ansDic['answer'])
|
||||||
|
ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
|
||||||
|
resAns = self.processPunctuation(resAns)
|
||||||
|
resAns = self.processDigitArticle(resAns)
|
||||||
|
|
||||||
|
for gtAnsDatum in gts[quesId]['answers']:
|
||||||
|
otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
|
||||||
|
matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
|
||||||
|
acc = min(1, float(len(matchingAns))/3)
|
||||||
|
gtAcc.append(acc)
|
||||||
|
quesType = gts[quesId]['question_type']
|
||||||
|
ansType = gts[quesId]['answer_type']
|
||||||
|
avgGTAcc = float(sum(gtAcc))/len(gtAcc)
|
||||||
|
accQA.append(avgGTAcc)
|
||||||
|
if quesType not in accQuesType:
|
||||||
|
accQuesType[quesType] = []
|
||||||
|
accQuesType[quesType].append(avgGTAcc)
|
||||||
|
if ansType not in accAnsType:
|
||||||
|
accAnsType[ansType] = []
|
||||||
|
accAnsType[ansType].append(avgGTAcc)
|
||||||
|
self.setEvalQA(quesId, avgGTAcc)
|
||||||
|
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
||||||
|
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
||||||
|
if step%100 == 0:
|
||||||
|
self.updateProgress(step/float(len(quesIds)))
|
||||||
|
step = step + 1
|
||||||
|
|
||||||
|
self.setAccuracy(accQA, accQuesType, accAnsType)
|
||||||
|
# print "Done computing accuracy"
|
||||||
|
|
||||||
|
def processPunctuation(self, inText):
|
||||||
|
outText = inText
|
||||||
|
for p in self.punct:
|
||||||
|
if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
|
||||||
|
outText = outText.replace(p, '')
|
||||||
|
else:
|
||||||
|
outText = outText.replace(p, ' ')
|
||||||
|
outText = self.periodStrip.sub("",
|
||||||
|
outText,
|
||||||
|
re.UNICODE)
|
||||||
|
return outText
|
||||||
|
|
||||||
|
def processDigitArticle(self, inText):
|
||||||
|
outText = []
|
||||||
|
tempText = inText.lower().split()
|
||||||
|
for word in tempText:
|
||||||
|
word = self.manualMap.setdefault(word, word)
|
||||||
|
if word not in self.articles:
|
||||||
|
outText.append(word)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
for wordId, word in enumerate(outText):
|
||||||
|
if word in self.contractions:
|
||||||
|
outText[wordId] = self.contractions[word]
|
||||||
|
outText = ' '.join(outText)
|
||||||
|
return outText
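# Illustrative examples of the answer normalization above (assumed inputs):
#   processPunctuation("4,000 miles!")         -> "4000 miles"
#   processDigitArticle("there are two dogs")  -> "there are 2 dogs"
#   processDigitArticle("a cat and the dog")   -> "cat and dog"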
|
||||||
|
|
||||||
|
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
||||||
|
self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
|
||||||
|
self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
|
||||||
|
self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
|
||||||
|
|
||||||
|
def setEvalQA(self, quesId, acc):
|
||||||
|
self.evalQA[quesId] = round(100*acc, self.n)
|
||||||
|
|
||||||
|
def setEvalQuesType(self, quesId, quesType, acc):
|
||||||
|
if quesType not in self.evalQuesType:
|
||||||
|
self.evalQuesType[quesType] = {}
|
||||||
|
self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
|
||||||
|
|
||||||
|
def setEvalAnsType(self, quesId, ansType, acc):
|
||||||
|
if ansType not in self.evalAnsType:
|
||||||
|
self.evalAnsType[ansType] = {}
|
||||||
|
self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
|
||||||
|
|
||||||
|
def updateProgress(self, progress):
|
||||||
|
barLength = 20
|
||||||
|
status = ""
|
||||||
|
if isinstance(progress, int):
|
||||||
|
progress = float(progress)
|
||||||
|
if not isinstance(progress, float):
|
||||||
|
progress = 0
|
||||||
|
status = "error: progress var must be float\r\n"
|
||||||
|
if progress < 0:
|
||||||
|
progress = 0
|
||||||
|
status = "Halt...\r\n"
|
||||||
|
if progress >= 1:
|
||||||
|
progress = 1
|
||||||
|
status = "Done...\r\n"
|
||||||
|
block = int(round(barLength*progress))
|
||||||
|
text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
|
||||||
|
sys.stdout.write(text)
|
||||||
|
sys.stdout.flush()
|
73
models/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
from vqaTools.vqa import VQA
|
||||||
|
import random
|
||||||
|
import skimage.io as io
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
dataDir ='../../VQA'
|
||||||
|
versionType ='v2_' # 'v2_' when using the VQA v2.0 dataset; this should be '' for v1.0
|
||||||
|
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
|
||||||
|
dataType ='mscoco' # 'mscoco' only for v2.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
|
||||||
|
dataSubType ='train2014'
|
||||||
|
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
|
||||||
|
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
|
||||||
|
imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
|
||||||
|
|
||||||
|
# initialize VQA api for QA annotations
|
||||||
|
vqa=VQA(annFile, quesFile)
|
||||||
|
|
||||||
|
# load and display QA annotations for given question types
|
||||||
|
"""
|
||||||
|
All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
|
||||||
|
"""
|
||||||
|
annIds = vqa.getQuesIds(quesTypes='how many');
|
||||||
|
anns = vqa.loadQA(annIds)
|
||||||
|
randomAnn = random.choice(anns)
|
||||||
|
vqa.showQA([randomAnn])
|
||||||
|
imgId = randomAnn['image_id']
|
||||||
|
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
||||||
|
if os.path.isfile(imgDir + imgFilename):
|
||||||
|
I = io.imread(imgDir + imgFilename)
|
||||||
|
plt.imshow(I)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# load and display QA annotations for given answer types
|
||||||
|
"""
|
||||||
|
ansTypes can be one of the following
|
||||||
|
yes/no
|
||||||
|
number
|
||||||
|
other
|
||||||
|
"""
|
||||||
|
annIds = vqa.getQuesIds(ansTypes='yes/no');
|
||||||
|
anns = vqa.loadQA(annIds)
|
||||||
|
randomAnn = random.choice(anns)
|
||||||
|
vqa.showQA([randomAnn])
|
||||||
|
imgId = randomAnn['image_id']
|
||||||
|
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
||||||
|
if os.path.isfile(imgDir + imgFilename):
|
||||||
|
I = io.imread(imgDir + imgFilename)
|
||||||
|
plt.imshow(I)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# load and display QA annotations for given images
|
||||||
|
"""
|
||||||
|
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
|
||||||
|
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
|
||||||
|
"""
|
||||||
|
ids = vqa.getImgIds()
|
||||||
|
annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
|
||||||
|
anns = vqa.loadQA(annIds)
|
||||||
|
randomAnn = random.choice(anns)
|
||||||
|
vqa.showQA([randomAnn])
|
||||||
|
imgId = randomAnn['image_id']
|
||||||
|
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
||||||
|
if os.path.isfile(imgDir + imgFilename):
|
||||||
|
I = io.imread(imgDir + imgFilename)
|
||||||
|
plt.imshow(I)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
__author__ = 'aagrawal'
|
179
models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
__author__ = 'aagrawal'
|
||||||
|
__version__ = '0.9'
|
||||||
|
|
||||||
|
# Interface for accessing the VQA dataset.
|
||||||
|
|
||||||
|
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
||||||
|
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
||||||
|
|
||||||
|
# The following functions are defined:
|
||||||
|
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
||||||
|
# getQuesIds - Get question ids that satisfy given filter conditions.
|
||||||
|
# getImgIds - Get image ids that satisfy given filter conditions.
|
||||||
|
# loadQA - Load questions and answers with the specified question ids.
|
||||||
|
# showQA - Display the specified questions and answers.
|
||||||
|
# loadRes - Load result file and create result object.
|
||||||
|
|
||||||
|
# Help on each function can be accessed by: "help(COCO.function)"
|
||||||
|
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class VQA:
|
||||||
|
def __init__(self, annotation_file=None, question_file=None):
|
||||||
|
"""
|
||||||
|
Constructor of VQA helper class for reading and visualizing questions and answers.
|
||||||
|
:param annotation_file (str): location of VQA annotation file
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# load dataset
|
||||||
|
self.dataset = {}
|
||||||
|
self.questions = {}
|
||||||
|
self.qa = {}
|
||||||
|
self.qqa = {}
|
||||||
|
self.imgToQA = {}
|
||||||
|
if not annotation_file == None and not question_file == None:
|
||||||
|
# print 'loading VQA annotations and questions into memory...'
|
||||||
|
time_t = datetime.datetime.utcnow()
|
||||||
|
dataset = json.load(open(annotation_file, 'r'))
|
||||||
|
questions = json.load(open(question_file, 'r'))
|
||||||
|
# print datetime.datetime.utcnow() - time_t
|
||||||
|
self.dataset = dataset
|
||||||
|
self.questions = questions
|
||||||
|
self.createIndex()
|
||||||
|
|
||||||
|
def createIndex(self):
|
||||||
|
imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
|
||||||
|
qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
|
||||||
|
qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
|
||||||
|
for ann in self.dataset['annotations']:
|
||||||
|
imgToQA[ann['image_id']] += [ann]
|
||||||
|
qa[ann['question_id']] = ann
|
||||||
|
for ques in self.questions['questions']:
|
||||||
|
qqa[ques['question_id']] = ques
|
||||||
|
# print 'index created!'
|
||||||
|
|
||||||
|
# create class members
|
||||||
|
self.qa = qa
|
||||||
|
self.qqa = qqa
|
||||||
|
self.imgToQA = imgToQA
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
"""
|
||||||
|
Print information about the VQA annotation file.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# for key, value in self.dataset['info'].items():
|
||||||
|
# print '%s: %s'%(key, value)
|
||||||
|
|
||||||
|
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
|
||||||
|
"""
|
||||||
|
Get question ids that satisfy the given filter conditions. An empty filter is skipped by default.
|
||||||
|
:param imgIds (int array) : get question ids for given imgs
|
||||||
|
quesTypes (str array) : get question ids for given question types
|
||||||
|
ansTypes (str array) : get question ids for given answer types
|
||||||
|
:return: ids (int array) : integer array of question ids
|
||||||
|
"""
|
||||||
|
imgIds = imgIds if type(imgIds) == list else [imgIds]
|
||||||
|
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
||||||
|
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
||||||
|
|
||||||
|
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
|
||||||
|
anns = self.dataset['annotations']
|
||||||
|
else:
|
||||||
|
if not len(imgIds) == 0:
|
||||||
|
anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
|
||||||
|
else:
|
||||||
|
anns = self.dataset['annotations']
|
||||||
|
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
|
||||||
|
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
|
||||||
|
ids = [ann['question_id'] for ann in anns]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
|
||||||
|
"""
|
||||||
|
Get image ids that satisfy the given filter conditions. An empty filter is skipped by default.
|
||||||
|
:param quesIds (int array) : get image ids for given question ids
|
||||||
|
quesTypes (str array) : get image ids for given question types
|
||||||
|
ansTypes (str array) : get image ids for given answer types
|
||||||
|
:return: ids (int array) : integer array of image ids
|
||||||
|
"""
|
||||||
|
quesIds = quesIds if type(quesIds) == list else [quesIds]
|
||||||
|
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
||||||
|
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
||||||
|
|
||||||
|
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
|
||||||
|
anns = self.dataset['annotations']
|
||||||
|
else:
|
||||||
|
if not len(quesIds) == 0:
|
||||||
|
anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
|
||||||
|
else:
|
||||||
|
anns = self.dataset['annotations']
|
||||||
|
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
|
||||||
|
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
|
||||||
|
ids = [ann['image_id'] for ann in anns]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def loadQA(self, ids=[]):
|
||||||
|
"""
|
||||||
|
Load questions and answers with the specified question ids.
|
||||||
|
:param ids (int array) : integer ids specifying question ids
|
||||||
|
:return: qa (object array) : loaded qa objects
|
||||||
|
"""
|
||||||
|
if type(ids) == list:
|
||||||
|
return [self.qa[id] for id in ids]
|
||||||
|
elif type(ids) == int:
|
||||||
|
return [self.qa[ids]]
|
||||||
|
|
||||||
|
def showQA(self, anns):
|
||||||
|
"""
|
||||||
|
Display the specified annotations.
|
||||||
|
:param anns (array of object): annotations to display
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if len(anns) == 0:
|
||||||
|
return 0
|
||||||
|
for ann in anns:
|
||||||
|
quesId = ann['question_id']
|
||||||
|
print("Question: %s" % (self.qqa[quesId]['question']))
|
||||||
|
for ans in ann['answers']:
|
||||||
|
print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
|
||||||
|
|
||||||
|
def loadRes(self, resFile, quesFile):
|
||||||
|
"""
|
||||||
|
Load result file and return a result object.
|
||||||
|
:param resFile (str) : file name of result file
|
||||||
|
:return: res (obj) : result api object
|
||||||
|
"""
|
||||||
|
res = VQA()
|
||||||
|
res.questions = json.load(open(quesFile))
|
||||||
|
res.dataset['info'] = copy.deepcopy(self.questions['info'])
|
||||||
|
res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
|
||||||
|
res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
|
||||||
|
res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
|
||||||
|
res.dataset['license'] = copy.deepcopy(self.questions['license'])
|
||||||
|
|
||||||
|
# print 'Loading and preparing results... '
|
||||||
|
time_t = datetime.datetime.utcnow()
|
||||||
|
anns = json.load(open(resFile))
|
||||||
|
assert type(anns) == list, 'results is not an array of objects'
|
||||||
|
annsQuesIds = [ann['question_id'] for ann in anns]
|
||||||
|
assert set(annsQuesIds) == set(self.getQuesIds()), \
|
||||||
|
'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in the annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
|
||||||
|
for ann in anns:
|
||||||
|
quesId = ann['question_id']
|
||||||
|
if res.dataset['task_type'] == 'Multiple Choice':
|
||||||
|
assert ann['answer'] in self.qqa[quesId][
|
||||||
|
'multiple_choices'], 'predicted answer is not one of the multiple choices'
|
||||||
|
qaAnn = self.qa[quesId]
|
||||||
|
ann['image_id'] = qaAnn['image_id']
|
||||||
|
ann['question_type'] = qaAnn['question_type']
|
||||||
|
ann['answer_type'] = qaAnn['answer_type']
|
||||||
|
# print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
|
||||||
|
|
||||||
|
res.dataset['annotations'] = anns
|
||||||
|
res.createIndex()
|
||||||
|
return res
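# Note: loadRes expects the results file to be a JSON list with one entry per
# question id in the annotation file, e.g. (illustrative, placeholders only)
# [{"question_id": <int>, "answer": "<answer>"}, ...]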
|
80
models/common/vqa_tools/VQA/README.md
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
|
||||||
|
===================
|
||||||
|
## VQA v2.0 release ##
|
||||||
|
This release consists of
|
||||||
|
- Real
|
||||||
|
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download))
|
||||||
|
- 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
|
||||||
|
- 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
|
||||||
|
|
||||||
|
There is only one type of task
|
||||||
|
- Open-ended task
|
||||||
|
|
||||||
|
## VQA v1.0 release ##
|
||||||
|
This release consists of
|
||||||
|
- Real
|
||||||
|
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download))
|
||||||
|
- 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
|
||||||
|
- 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
|
||||||
|
- Abstract
|
||||||
|
- 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
|
||||||
|
- 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
|
||||||
|
- 600,000 answers for training and 300,000 answers for validation (10 per question)
|
||||||
|
|
||||||
|
There are two types of tasks
|
||||||
|
- Open-ended task
|
||||||
|
- Multiple-choice task (18 choices per question)
|
||||||
|
|
||||||
|
## Requirements ##
|
||||||
|
- python 2.7
|
||||||
|
- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
|
||||||
|
- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
|
||||||
|
|
||||||
|
## Files ##
|
||||||
|
./Questions
|
||||||
|
- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
|
||||||
|
- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
|
||||||
|
- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
|
||||||
|
- [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
|
||||||
|
- [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
|
||||||
|
- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
|
||||||
|
|
||||||
|
./Annotations
|
||||||
|
- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
|
||||||
|
- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
|
||||||
|
- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
|
||||||
|
- [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
|
||||||
|
- [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
|
||||||
|
- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
|
||||||
|
|
||||||
|
./Images
|
||||||
|
- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
|
||||||
|
- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
|
||||||
|
|
||||||
|
./PythonHelperTools
|
||||||
|
- This directory contains the Python API to read and visualize the VQA dataset
|
||||||
|
- vqaDemo.py (demo script)
|
||||||
|
- vqaTools (API to read and visualize data)
|
||||||
|
|
||||||
|
./PythonEvaluationTools
|
||||||
|
- This directory contains the Python evaluation code
|
||||||
|
- vqaEvalDemo.py (evaluation demo script)
|
||||||
|
- vqaEvaluation (evaluation code)
|
||||||
|
|
||||||
|
./Results
|
||||||
|
- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
|
||||||
|
- Visit [VQA evaluation page](http://visualqa.org/evaluation) for more details.
|
||||||
|
|
||||||
|
./QuestionTypes
|
||||||
|
- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k (for example, questions of the type "what color is the" are not also counted under "what color").
|
||||||
|
- mscoco_question_types.txt
|
||||||
|
- abstract_v002_question_types.txt
|
||||||
|
|
||||||
|
## References ##
|
||||||
|
- [VQA: Visual Question Answering](http://visualqa.org/)
|
||||||
|
- [Microsoft COCO](http://mscoco.org/)
|
||||||
|
|
||||||
|
## Developers ##
|
||||||
|
- Aishwarya Agrawal (Virginia Tech)
|
||||||
|
- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
|
||||||
|
- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
|
8
models/common/vqa_tools/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
__author__ = "aagrawal"
|
201
models/common/vqa_tools/aokvqa/LICENSE
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2022 Allen Institute for Artificial Intelligence
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
207
models/common/vqa_tools/aokvqa/README.md
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
# A-OKVQA
|
||||||
|
|
||||||
|
Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**.
|
||||||
|
|
||||||
|
Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public)
|
||||||
|
|
||||||
|
### Abstract
|
||||||
|
|
||||||
|
The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
<hr>
|
||||||
|
|
||||||
|
#### Table of Contents
|
||||||
|
|
||||||
|
- [Getting started](#getting-started)
|
||||||
|
* [Downloading the dataset](#downloading-the-dataset)
|
||||||
|
- [Evaluation & Leaderboard](#evaluation)
|
||||||
|
- [Codebase](#codebase)
|
||||||
|
* [Preparing data](#preparing-data)
|
||||||
|
* [Models and Predictions](#models-and-predictions)
|
||||||
|
|
||||||
|
<hr>
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git
|
||||||
|
|
||||||
|
cd aokvqa
|
||||||
|
export PYTHONPATH=.
|
||||||
|
|
||||||
|
conda env create --name aokvqa
|
||||||
|
conda activate aokvqa
|
||||||
|
```
|
||||||
|
|
||||||
|
### Downloading the dataset
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export AOKVQA_DIR=./datasets/aokvqa/
|
||||||
|
mkdir -p ${AOKVQA_DIR}
|
||||||
|
|
||||||
|
curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
|
||||||
|
```
|
||||||
|
|
||||||
|
<details> <summary><b>Downloading COCO 2017</b></summary>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export COCO_DIR=./datasets/coco/
|
||||||
|
mkdir -p ${COCO_DIR}
|
||||||
|
|
||||||
|
for split in train val test; do
|
||||||
|
wget "http://images.cocodataset.org/zips/${split}2017.zip"
|
||||||
|
unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
|
||||||
|
done
|
||||||
|
|
||||||
|
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
|
||||||
|
unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
aokvqa_dir = os.getenv('AOKVQA_DIR')
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa, get_coco_path
|
||||||
|
train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test'
|
||||||
|
```
|
||||||
|
|
||||||
|
<details> <summary><b>Example dataset entry</b></summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
dataset_example = train_dataset[0]
|
||||||
|
|
||||||
|
print(dataset_example['question_id'])
|
||||||
|
# 22MexNkBPpdZGX6sxbxVBH
|
||||||
|
|
||||||
|
coco_dir = os.getenv('COCO_DIR')
|
||||||
|
image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
|
||||||
|
print(image_path)
|
||||||
|
# ./datasets/coco/train2017/000000299207.jpg
|
||||||
|
|
||||||
|
print(dataset_example['question'])
|
||||||
|
print(dataset_example['choices'])
|
||||||
|
# What is the man by the bags awaiting?
|
||||||
|
# ['skateboarder', 'train', 'delivery', 'cab']
|
||||||
|
|
||||||
|
correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
|
||||||
|
# Correct: cab
|
||||||
|
|
||||||
|
print(dataset_example['rationales'][0])
|
||||||
|
# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either `multiple_choice` or `direct_answer` field if you only want to evaluate one setting.
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
'<question_id>' : {
|
||||||
|
'multiple_choice' : '<prediction>',
|
||||||
|
'direct_answer' : '<prediction>'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
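As a minimal sketch (reusing the example question id shown earlier), such a file can be written like this:

```python
import json

# hypothetical predictions keyed by question_id
predictions = {
    "22MexNkBPpdZGX6sxbxVBH": {
        "multiple_choice": "cab",
        "direct_answer": "cab",
    },
}

with open("predictions_val.json", "w") as f:
    json.dump(predictions, f)
```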
|
||||||
|
|
||||||
|
You can run evaluation on the validation set as follows.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Leaderboard
|
||||||
|
|
||||||
|
You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
|
||||||
|
|
||||||
|
## Codebase
|
||||||
|
|
||||||
|
We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
|
||||||
|
|
||||||
|
### Preparing data
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export FEATURES_DIR=./features/
|
||||||
|
mkdir -p ${FEATURES_DIR}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
|
||||||
|
|
||||||
|
for split in train val test; do
|
||||||
|
python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
<details> <summary><b>For training ClipCap with a transformer mapping network</b></summary>
|
||||||
|
|
||||||
|
If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details> <summary><b>For ResNet and BERT input features</b></summary>
|
||||||
|
|
||||||
|
Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# ResNet
|
||||||
|
for split in train val test; do
|
||||||
|
python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
|
||||||
|
done
|
||||||
|
|
||||||
|
# BERT
|
||||||
|
for split in train val test; do
|
||||||
|
python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### Models and Predictions
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LOG_DIR=./logs/
|
||||||
|
export PREDS_DIR=./predictions/
|
||||||
|
export PT_MODEL_DIR=./pretrained_models/
|
||||||
|
mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR}
|
||||||
|
```
|
||||||
|
|
||||||
|
<details> <summary><b>Download our pretrained model weights</b></summary>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Checkpoints for transfer learning experiments
|
||||||
|
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
|
||||||
|
|
||||||
|
# Checkpoints for ClipCap models (generating answers and rationales)
|
||||||
|
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
We have included instructions for replicating each of our experiments (see README.md files below).
|
||||||
|
|
||||||
|
All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above.
|
||||||
|
|
||||||
|
- [Heuristics](./heuristics/README.md)
|
||||||
|
- [Transfer Learning Experiments](./transfer_experiments/README.md)
|
||||||
|
- [Querying GPT-3](./gpt3/README.md)
|
||||||
|
- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
|
||||||
|
- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
|
||||||
|
|
||||||
|
For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set.
|
||||||
|
|
||||||
|
We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json
|
||||||
|
# repeat for test split ...
|
||||||
|
```
|
45
models/common/vqa_tools/aokvqa/data_scripts/build_vocab.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from collections import Counter
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# Build vocab from train set: correct choices + (choices and direct answers appearing in >= 3 questions)
|
||||||
|
|
||||||
|
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
||||||
|
|
||||||
|
vocab = []
|
||||||
|
all_choices = Counter()
|
||||||
|
direct_answers = Counter()
|
||||||
|
|
||||||
|
for i in train_set:
|
||||||
|
vocab.append( i['choices'][i['correct_choice_idx']] )
|
||||||
|
all_choices.update(i['choices'])
|
||||||
|
direct_answers.update(set(i['direct_answers']))
|
||||||
|
vocab += [k for k,v in all_choices.items() if v >= 3]
|
||||||
|
vocab += [k for k,v in direct_answers.items() if v >= 3]
|
||||||
|
|
||||||
|
vocab = sorted(set(vocab))
|
||||||
|
print(f"Vocab size: {len(vocab)}")
|
||||||
|
|
||||||
|
# Save vocabulary Output
|
||||||
|
|
||||||
|
with open(args.output_file, 'w') as f:
|
||||||
|
for v in vocab:
|
||||||
|
print(v, file=f)
|
||||||
|
|
||||||
|
## Check validation set coverage
|
||||||
|
|
||||||
|
val_set = load_aokvqa(args.aokvqa_dir, 'val')
|
||||||
|
|
||||||
|
val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set]
|
||||||
|
val_acc = sum(val_acc) / len(val_acc) * 100
|
||||||
|
print(f"Val set coverage: {val_acc:.2f}" )
|
|
@ -0,0 +1,26 @@
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import clip
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file')
|
||||||
|
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
|
||||||
|
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert args.output_file.suffix == '.pt'
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model, preprocess = clip.load(args.model_type, device=device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
a = open(args.vocab_file).read().splitlines()
|
||||||
|
mc_text = clip.tokenize(a).to(device)
|
||||||
|
mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0]
|
||||||
|
mc_text_features = mc_text_features.float()
|
||||||
|
model_name = args.model_type.replace('/', '-').replace('@', '-')
|
||||||
|
torch.save(mc_text_features, args.output_file)
|
|
@ -0,0 +1,50 @@
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModel
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert args.output_file.suffix == '.pt'
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
## Load model
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
|
||||||
|
model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
def mean_pooling(model_output, attention_mask):
|
||||||
|
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
|
||||||
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||||
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||||
|
|
||||||
|
## Encoding loop
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
embeddings = {}
|
||||||
|
|
||||||
|
for d in tqdm(dataset):
|
||||||
|
encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt')
|
||||||
|
encoded_input = {k:v.to(device) for k,v in encoded_input.items()}
|
||||||
|
e = mean_pooling(model(**encoded_input), encoded_input['attention_mask'])
|
||||||
|
embeddings[d['question_id']] = {
|
||||||
|
'question' : e[0].cpu()
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embeddings, args.output_file)
|
|
@ -0,0 +1,51 @@
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import clip
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa, get_coco_path
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
|
||||||
|
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert args.output_file.suffix == '.pt'
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
## Load model
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model, preprocess = clip.load(args.model_type, device=device)
|
||||||
|
|
||||||
|
## Encoding loop
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
embeddings = {}
|
||||||
|
|
||||||
|
for d in tqdm(dataset):
|
||||||
|
q = d["question"]
|
||||||
|
q_text = clip.tokenize(q).to(device)
|
||||||
|
q_text_features = model.encode_text(q_text)
|
||||||
|
|
||||||
|
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir))
|
||||||
|
img = preprocess(img).unsqueeze(0).to(device)
|
||||||
|
image_features = model.encode_image(img)
|
||||||
|
|
||||||
|
embeddings[d['question_id']] = {
|
||||||
|
'question' : q_text_features[0].float().cpu(),
|
||||||
|
'image' : image_features[0].float().cpu(),
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embeddings, args.output_file)
|
|
@ -0,0 +1,62 @@
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import models
|
||||||
|
from torchvision import transforms as T
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa, get_coco_path
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert args.output_file.suffix == '.pt'
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
## Load model
|
||||||
|
|
||||||
|
resnet_preprocess = T.Compose([
|
||||||
|
T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC),
|
||||||
|
T.CenterCrop(size=(224, 224)),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
resnet_model = models.resnet50(pretrained=True)
|
||||||
|
resnet_model = torch.nn.Sequential(
|
||||||
|
*list(resnet_model.children())[:-1],
|
||||||
|
nn.Flatten()
|
||||||
|
) # strip classification layer
|
||||||
|
resnet_model = resnet_model.to(device)
|
||||||
|
|
||||||
|
## Encoding loop
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
embeddings = {}
|
||||||
|
|
||||||
|
for d in tqdm(dataset):
|
||||||
|
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB')
|
||||||
|
resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)
|
||||||
|
resnet_features = resnet_model(resnet_input)
|
||||||
|
embeddings[d['question_id']] = {
|
||||||
|
'image' : resnet_features[0].cpu()
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embeddings, args.output_file)
|
36
models/common/vqa_tools/aokvqa/environment.yml
Normal file
36
models/common/vqa_tools/aokvqa/environment.yml
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
name: aokvqa
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- nvidia
|
||||||
|
- huggingface
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.7
|
||||||
|
- cudatoolkit=11.3
|
||||||
|
- numpy=1.21.6
|
||||||
|
- pytorch=1.11.0
|
||||||
|
- torchvision=0.12.0
|
||||||
|
- pytorch-lightning=1.6.3
|
||||||
|
- torchmetrics=0.8.1
|
||||||
|
- gdown=4.4.0
|
||||||
|
- pip=22.0.4
|
||||||
|
- pip:
|
||||||
|
- argparse==1.4.0
|
||||||
|
- Pillow==9.0.1
|
||||||
|
- tensorboard==2.9.0
|
||||||
|
- ftfy==6.1.1
|
||||||
|
- regex==2022.3.15
|
||||||
|
- tqdm==4.64.0
|
||||||
|
- clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
|
||||||
|
- openai==0.18.1
|
||||||
|
- nltk==3.7
|
||||||
|
- sacrebleu==2.0.0
|
||||||
|
- sacremoses==0.0.53
|
||||||
|
- sentence-transformers==2.2.0
|
||||||
|
- datasets==2.1.0
|
||||||
|
- tokenizers==0.10.3
|
||||||
|
- transformers==4.10.3
|
||||||
|
|
||||||
|
# Next: resolve conflict between sentence-transfomers and pytorch-lightning
|
||||||
|
# pip uninstall sentencepiece
|
|
@ -0,0 +1,97 @@
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):
|
||||||
|
|
||||||
|
if isinstance(dataset, list):
|
||||||
|
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
|
||||||
|
|
||||||
|
if multiple_choice is False:
|
||||||
|
dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
|
||||||
|
|
||||||
|
if strict:
|
||||||
|
dataset_qids = set(dataset.keys())
|
||||||
|
preds_qids = set(preds.keys())
|
||||||
|
assert dataset_qids.issubset(preds_qids)
|
||||||
|
|
||||||
|
# dataset = q_id (str) : dataset element (dict)
|
||||||
|
# preds = q_id (str) : prediction (str)
|
||||||
|
|
||||||
|
acc = []
|
||||||
|
|
||||||
|
for q in dataset.keys():
|
||||||
|
if q not in preds.keys():
|
||||||
|
acc.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
pred = preds[q]
|
||||||
|
choices = dataset[q]['choices']
|
||||||
|
direct_answers = dataset[q]['direct_answers']
|
||||||
|
|
||||||
|
## Multiple Choice setting
|
||||||
|
if multiple_choice:
|
||||||
|
if strict:
|
||||||
|
assert pred in choices, 'Prediction must be a valid choice'
|
||||||
|
correct_choice_idx = dataset[q]['correct_choice_idx']
|
||||||
|
acc.append( float(pred == choices[correct_choice_idx]) )
|
||||||
|
## Direct Answer setting
|
||||||
|
else:
|
||||||
|
num_match = sum([pred.lower() == da.lower() for da in direct_answers])
|
||||||
|
vqa_acc = min(1.0, num_match / 3.0)
|
||||||
|
acc.append(vqa_acc)
|
||||||
|
|
||||||
|
acc = sum(acc) / len(acc) * 100
|
||||||
|
|
||||||
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
for prediction_file in glob.glob(args.prediction_files):
|
||||||
|
predictions = json.load(open(prediction_file, 'r'))
|
||||||
|
|
||||||
|
# Multiple choice
|
||||||
|
|
||||||
|
mc_predictions = {}
|
||||||
|
|
||||||
|
for q in predictions.keys():
|
||||||
|
if 'multiple_choice' in predictions[q].keys():
|
||||||
|
mc_predictions[q] = predictions[q]['multiple_choice']
|
||||||
|
|
||||||
|
if mc_predictions != {}:
|
||||||
|
mc_acc = eval_aokvqa(
|
||||||
|
dataset,
|
||||||
|
mc_predictions,
|
||||||
|
multiple_choice=True,
|
||||||
|
strict=False
|
||||||
|
)
|
||||||
|
print(prediction_file, 'MC', mc_acc)
|
||||||
|
|
||||||
|
# Direct Answer
|
||||||
|
|
||||||
|
da_predictions = {}
|
||||||
|
|
||||||
|
for q in predictions.keys():
|
||||||
|
if 'direct_answer' in predictions[q].keys():
|
||||||
|
da_predictions[q] = predictions[q]['direct_answer']
|
||||||
|
|
||||||
|
if da_predictions != {}:
|
||||||
|
da_acc = eval_aokvqa(
|
||||||
|
dataset,
|
||||||
|
da_predictions,
|
||||||
|
multiple_choice=False,
|
||||||
|
strict=False
|
||||||
|
)
|
||||||
|
print(prediction_file, 'DA', da_acc)
|
13
models/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py
Normal file
13
models/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def load_aokvqa(aokvqa_dir, split, version='v1p0'):
|
||||||
|
assert split in ['train', 'val', 'test', 'test_w_ans']
|
||||||
|
dataset = json.load(open(
|
||||||
|
os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
|
||||||
|
))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def get_coco_path(split, image_id, coco_dir):
|
||||||
|
return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
|
|
@ -0,0 +1,31 @@
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file')
|
||||||
|
parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
assert args.mc_pred_file or args.da_pred_file
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None
|
||||||
|
da_preds = json.load(args.da_pred_file) if args.da_pred_file else None
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
for d in dataset:
|
||||||
|
q = d['question_id']
|
||||||
|
predictions[q] = {}
|
||||||
|
if mc_preds and q in mc_preds.keys():
|
||||||
|
predictions[q]['multiple_choice'] = mc_preds[q]
|
||||||
|
if da_preds and q in da_preds.keys():
|
||||||
|
predictions[q]['direct_answer'] = da_preds[q]
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
|
@ -0,0 +1,44 @@
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from sentence_transformers.util import cos_sim
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
def map_to_choices(dataset, predictions, device='cpu'):
|
||||||
|
if isinstance(dataset, list):
|
||||||
|
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
|
||||||
|
|
||||||
|
if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
|
||||||
|
model.to(device)
|
||||||
|
for q in tqdm(predictions.keys()):
|
||||||
|
choices = dataset[q]['choices']
|
||||||
|
if predictions[q] not in choices:
|
||||||
|
choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
|
||||||
|
a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
|
||||||
|
predictions[q] = choices[a_idx]
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
predictions = json.load(args.prediction_file)
|
||||||
|
predictions = map_to_choices(dataset, predictions)
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
14
models/common/vqa_tools/aokvqa/gpt3/README.md
Normal file
14
models/common/vqa_tools/aokvqa/gpt3/README.md
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
## Querying GPT-3
|
||||||
|
|
||||||
|
To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_ORG=....
|
||||||
|
export OPENAI_API_KEY=...
|
||||||
|
```
|
||||||
|
|
||||||
|
For producing predictions for both DA and MC settings, run:
|
||||||
|
```bash
|
||||||
|
python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json
|
||||||
|
python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json
|
||||||
|
```
|
23
models/common/vqa_tools/aokvqa/gpt3/caption_inputs.py
Normal file
23
models/common/vqa_tools/aokvqa/gpt3/caption_inputs.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val'], required=True)
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations']
|
||||||
|
coco_captions = {c['image_id'] : c['caption'] for c in coco_captions}
|
||||||
|
|
||||||
|
captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set }
|
||||||
|
|
||||||
|
json.dump(captions, args.output_file)
|
79
models/common/vqa_tools/aokvqa/gpt3/query_gpt3.py
Normal file
79
models/common/vqa_tools/aokvqa/gpt3/query_gpt3.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import openai
|
||||||
|
openai.organization = os.getenv('OPENAI_ORG')
|
||||||
|
openai.api_key = os.getenv('OPENAI_API_KEY')
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--n', type=int, default=10, dest='num_examples')
|
||||||
|
parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
|
||||||
|
parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
|
||||||
|
parser.add_argument('--include-choices', action='store_true', dest='include_choices')
|
||||||
|
parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
||||||
|
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
train_context = {}
|
||||||
|
context = {}
|
||||||
|
if args.context_file is not None:
|
||||||
|
train_context = json.load(args.train_context_file)
|
||||||
|
context = json.load(args.context_file)
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
for d in tqdm(eval_set):
|
||||||
|
q = d['question_id']
|
||||||
|
|
||||||
|
prompt = args.prompt_prefix
|
||||||
|
for e in random.sample(train_set, args.num_examples):
|
||||||
|
prompt += prompt_element(e,
|
||||||
|
context=train_context.get(q, None),
|
||||||
|
include_choices=args.include_choices,
|
||||||
|
answer=True
|
||||||
|
)
|
||||||
|
prompt += '\n\n'
|
||||||
|
|
||||||
|
prompt += prompt_element(d,
|
||||||
|
context=context.get(q, None),
|
||||||
|
include_choices=args.include_choices,
|
||||||
|
answer=False
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai.Completion.create(
|
||||||
|
engine="text-curie-001",
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
predictions[q] = response.choices[0].text.strip()
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_element(d, context=None, include_choices=False, answer=False):
|
||||||
|
return (f"Context: {context}\n" if context is not None else '') + \
|
||||||
|
f"Q: {d['question']}\n" + \
|
||||||
|
(f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
|
||||||
|
f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
16
models/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py
Normal file
16
models/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set}
|
||||||
|
json.dump(rationales, args.output_file)
|
11
models/common/vqa_tools/aokvqa/heuristics/README.md
Normal file
11
models/common/vqa_tools/aokvqa/heuristics/README.md
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
## Heuristics
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# These scripts accept the same arguments.
|
||||||
|
# heuristics/random_unweighted.py
|
||||||
|
# heuristics/random_weighted.py
|
||||||
|
# heuristics/most_common_answer.py
|
||||||
|
|
||||||
|
python heuristics/random_unweighted.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc --out ${PREDS_DIR}/random-unweighted_val-mc.json
|
||||||
|
# Exclude --mc for the direct answer setting
|
||||||
|
```
|
|
@ -0,0 +1,39 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
||||||
|
train_freq = dict(Counter(
|
||||||
|
[d['choices'][d['correct_choice_idx']] for d in train_set]
|
||||||
|
))
|
||||||
|
most_common_answer = max(train_freq.keys(), key=train_freq.get)
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
for d in eval_set:
|
||||||
|
q = d['question_id']
|
||||||
|
predictions[q] = most_common_answer
|
||||||
|
|
||||||
|
if args.multiple_choice:
|
||||||
|
choices = [c for c in d['choices'] if c in train_freq.keys()]
|
||||||
|
if len(choices) > 0:
|
||||||
|
predictions[q] = max(choices, key=train_freq.get)
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
|
@ -0,0 +1,38 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from random import seed, sample
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
seed(0)
|
||||||
|
|
||||||
|
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
||||||
|
|
||||||
|
if args.multiple_choice is False:
|
||||||
|
choices = list(set(
|
||||||
|
[d['choices'][d['correct_choice_idx']] for d in train_set]
|
||||||
|
))
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
for d in eval_set:
|
||||||
|
q = d['question_id']
|
||||||
|
if args.multiple_choice:
|
||||||
|
choices = d['choices']
|
||||||
|
predictions[q] = sample(choices, 1)[0]
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
46
models/common/vqa_tools/aokvqa/heuristics/random_weighted.py
Normal file
46
models/common/vqa_tools/aokvqa/heuristics/random_weighted.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
|
||||||
|
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
||||||
|
train_freq = dict(Counter(
|
||||||
|
[d['choices'][d['correct_choice_idx']] for d in train_set]
|
||||||
|
))
|
||||||
|
|
||||||
|
if args.multiple_choice is False:
|
||||||
|
choices = list(train_freq.keys())
|
||||||
|
probs = [f / len(train_set) for f in train_freq.values()]
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
for d in eval_set:
|
||||||
|
if args.multiple_choice:
|
||||||
|
choices = d['choices']
|
||||||
|
probs = [train_freq.get(c, 0) for c in choices]
|
||||||
|
if probs == [0, 0, 0, 0]:
|
||||||
|
probs = [1, 1, 1, 1]
|
||||||
|
probs = [p / sum(probs) for p in probs]
|
||||||
|
|
||||||
|
q = d['question_id']
|
||||||
|
predictions[q] = np.random.choice(choices, size=1, p=probs)[0]
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
13
models/common/vqa_tools/aokvqa/load_aokvqa.py
Normal file
13
models/common/vqa_tools/aokvqa/load_aokvqa.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def load_aokvqa(aokvqa_dir, split, version='v1p0'):
|
||||||
|
assert split in ['train', 'val', 'test', 'test_w_ans']
|
||||||
|
dataset = json.load(open(
|
||||||
|
os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
|
||||||
|
))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def get_coco_path(split, image_id, coco_dir):
|
||||||
|
return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
|
|
@ -0,0 +1,41 @@
|
||||||
|
## Transfer Learning Experiments
|
||||||
|
|
||||||
|
We use the following training/prediction scripts for the classifier, zero-shot, and contrastive experiments in Table 3.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
## Training
|
||||||
|
python transfer_experiments/train.py --aokvqa-dir ${AOKVQA_DIR} --vocab ${AOKVQA_DIR}/large_vocab_train.csv --log-dir ${LOG_DIR}
|
||||||
|
|
||||||
|
--backbone clip --clip-model-type ViT-B/32 --train-features ${FEATURES_DIR}/clip-ViT-B-32_train.pt --val-features ${FEATURES_DIR}/clip-ViT-B-32_val.pt
|
||||||
|
--inputs question # OR --inputs image # OR --inputs question image
|
||||||
|
# OR
|
||||||
|
--backbone resnet --train-features ${FEATURES_DIR}/resnet_train.pt --val-features ${FEATURES_DIR}/resnet_val.pt --inputs image
|
||||||
|
# OR
|
||||||
|
--backbone bert --train-features ${FEATURES_DIR}/bert_train.pt --val-features ${FEATURES_DIR}/bert_val.pt --inputs question
|
||||||
|
|
||||||
|
--objective classifier
|
||||||
|
# OR
|
||||||
|
--objective contrastive --vocab-features ${FEATURE_DIR}/clip-ViT-B-32_large_vocab.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
You can make predictions for CLIP zero-shot or from a classifier/contrastive checkpoint trained above.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
## Predicting
|
||||||
|
python transfer_experiments/predict.py --aokvqa-dir ${AOKVQA_DIR} --out ${PREDS_DIR}/clip-classifier_val-mc.json
|
||||||
|
|
||||||
|
--split val # or test
|
||||||
|
--features ${FEATURE_DIR}/clip-ViT-B-32_val.pt # adjust for backbone and eval split
|
||||||
|
|
||||||
|
--ckpt path/to/model.ckpt
|
||||||
|
# OR
|
||||||
|
--zero-shot --clip-model-type ViT-B/32
|
||||||
|
--inputs question # OR --inputs image # OR --inputs question image
|
||||||
|
|
||||||
|
--mc # Multiple-choice. Exclude for direct-answer.
|
||||||
|
|
||||||
|
# IF classifier OR direct-answer
|
||||||
|
--vocab ${AOKVQA_DIR}/large_vocab_train.csv
|
||||||
|
# IF contrastive/zero-shot AND direct-answer
|
||||||
|
--vocab-features ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
|
||||||
|
```
|
126
models/common/vqa_tools/aokvqa/transfer_experiments/predict.py
Normal file
126
models/common/vqa_tools/aokvqa/transfer_experiments/predict.py
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
|
||||||
|
import sentencepiece; import pytorch_lightning as pl; import clip
|
||||||
|
|
||||||
|
from transfer_experiments.train import LinearClassifier
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
from evaluation.remap_predictions import map_to_choices
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--features', type=pathlib.Path, required=True)
|
||||||
|
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
|
||||||
|
#
|
||||||
|
parser_weights = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
|
||||||
|
parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')
|
||||||
|
|
||||||
|
parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
|
||||||
|
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv))
|
||||||
|
#
|
||||||
|
parser.add_argument('--vocab', type=argparse.FileType('r'))
|
||||||
|
parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
|
||||||
|
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
|
||||||
|
|
||||||
|
parser.add_argument('--clip-model-type', type=str,
|
||||||
|
choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
|
||||||
|
dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv))
|
||||||
|
#
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
|
||||||
|
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
|
||||||
|
|
||||||
|
## Load models
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
if args.checkpoint_path is not None:
|
||||||
|
classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
|
||||||
|
classifier.to(device)
|
||||||
|
hp = classifier.hparams
|
||||||
|
elif args.clip_zero_shot:
|
||||||
|
classifier = nn.Identity().to(device)
|
||||||
|
hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)
|
||||||
|
|
||||||
|
# Load input features
|
||||||
|
|
||||||
|
embeddings = torch.load(args.features)
|
||||||
|
if hp.backbone == 'clip':
|
||||||
|
for q in embeddings.keys():
|
||||||
|
embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True)
|
||||||
|
embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Load vocab, vocab features, clip
|
||||||
|
|
||||||
|
if (hp.objective == 'classifier') or \
|
||||||
|
(hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False):
|
||||||
|
vocab = args.vocab.read().splitlines()
|
||||||
|
|
||||||
|
if hp.objective in ['contrastive', 'zero-shot']:
|
||||||
|
if args.multiple_choice is False:
|
||||||
|
vocab_features = torch.load(args.vocab_features).cpu()
|
||||||
|
vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
|
||||||
|
else:
|
||||||
|
clip_model = clip.load(hp.clip_model_type, device=device)[0]
|
||||||
|
logit_scale = clip_model.logit_scale.exp().cpu()
|
||||||
|
|
||||||
|
## Prediction loop
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for o in tqdm(aokvqa_set):
|
||||||
|
q = o['question_id']
|
||||||
|
|
||||||
|
# Load input embedding (from question / image)
|
||||||
|
if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs):
|
||||||
|
e = embeddings[q]['question'] + embeddings[q]['image']
|
||||||
|
elif 'question' in hp.inputs and 'image' in hp.inputs:
|
||||||
|
e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
|
||||||
|
elif 'question' in hp.inputs:
|
||||||
|
e = embeddings[q]['question']
|
||||||
|
elif 'image' in hp.inputs:
|
||||||
|
e = embeddings[q]['image']
|
||||||
|
|
||||||
|
# Pass inputs through model
|
||||||
|
e = e.unsqueeze(0).to(device)
|
||||||
|
x = classifier(e)[0].cpu()
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
if hp.objective in ['contrastive', 'zero-shot']:
|
||||||
|
if args.multiple_choice:
|
||||||
|
vocab = o['choices']
|
||||||
|
# Encode choices
|
||||||
|
vocab_features = clip.tokenize(vocab).to(device)
|
||||||
|
vocab_features = torch.stack([
|
||||||
|
clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features
|
||||||
|
], dim=1)[0]
|
||||||
|
vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
|
||||||
|
vocab_features = vocab_features.float().cpu()
|
||||||
|
|
||||||
|
x = logit_scale * x @ vocab_features.t()
|
||||||
|
x = x.softmax(dim=-1)
|
||||||
|
|
||||||
|
predictions[q] = vocab[x.argmax().item()]
|
||||||
|
|
||||||
|
## Save and evaluate predictions
|
||||||
|
|
||||||
|
# Map prediction to nearest neighbor choice (by word embeddings)
|
||||||
|
if args.multiple_choice and hp.objective == 'classifier':
|
||||||
|
predictions = map_to_choices(aokvqa_set, predictions)
|
||||||
|
|
||||||
|
json.dump(predictions, args.output_file)
|
263
models/common/vqa_tools/aokvqa/transfer_experiments/train.py
Normal file
263
models/common/vqa_tools/aokvqa/transfer_experiments/train.py
Normal file
|
@ -0,0 +1,263 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
|
||||||
|
import sentencepiece; import pytorch_lightning as pl
|
||||||
|
|
||||||
|
import torchmetrics.functional as MF
|
||||||
|
|
||||||
|
from load_aokvqa import load_aokvqa
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
||||||
|
parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
|
||||||
|
parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
|
||||||
|
#
|
||||||
|
parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
|
||||||
|
parser.add_argument('--clip-model-type', type=str,
|
||||||
|
choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
|
||||||
|
dest='clip_model_type', required=('clip' in sys.argv))
|
||||||
|
parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
|
||||||
|
parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
|
||||||
|
parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
|
||||||
|
#
|
||||||
|
parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
|
||||||
|
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
|
||||||
|
# Defaults
|
||||||
|
parser.add_argument('--bs', type=int, default=128, dest='batch_size')
|
||||||
|
parser.add_argument('--lr', type=float, default=0.01)
|
||||||
|
parser.add_argument('--epochs', type=int, default=500)
|
||||||
|
parser.add_argument('--gpus', type=int, default=1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
pl.seed_everything(1)
|
||||||
|
vocab = args.vocab.read().splitlines()
|
||||||
|
|
||||||
|
## Data loading
|
||||||
|
|
||||||
|
dm = AokvqaEmbeddingsDataModule(
|
||||||
|
args.aokvqa_dir,
|
||||||
|
args.train_features,
|
||||||
|
args.val_features,
|
||||||
|
args.objective,
|
||||||
|
args.backbone,
|
||||||
|
args.inputs,
|
||||||
|
vocab,
|
||||||
|
args.vocab_features,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_workers=16
|
||||||
|
)
|
||||||
|
|
||||||
|
## Model definition
|
||||||
|
|
||||||
|
model = LinearClassifier(
|
||||||
|
args.objective,
|
||||||
|
args.backbone,
|
||||||
|
args.clip_model_type,
|
||||||
|
args.inputs,
|
||||||
|
len(vocab),
|
||||||
|
args.lr
|
||||||
|
)
|
||||||
|
|
||||||
|
## Training and testing loops
|
||||||
|
|
||||||
|
logger = pl.loggers.TensorBoardLogger(
|
||||||
|
args.log_dir,
|
||||||
|
name=f'{args.backbone}-{args.objective}',
|
||||||
|
version=f"inputs:{'+'.join(args.inputs)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
logger=logger,
|
||||||
|
gpus=args.gpus,
|
||||||
|
max_epochs=args.epochs,
|
||||||
|
callbacks=[
|
||||||
|
pl.callbacks.ModelCheckpoint(
|
||||||
|
monitor="val_acc",
|
||||||
|
filename="{epoch:02d}-{val_acc:.2f}",
|
||||||
|
mode="max"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(model, dm)
|
||||||
|
|
||||||
|
|
||||||
|
class AokvqaEmbeddingsDataset(Dataset):
|
||||||
|
def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):
|
||||||
|
|
||||||
|
aokvqa_set = load_aokvqa(aokvqa_dir, split)
|
||||||
|
|
||||||
|
assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
|
||||||
|
or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
|
||||||
|
or ( backbone == 'clip' )
|
||||||
|
|
||||||
|
embeddings = torch.load(input_features)
|
||||||
|
if backbone == 'clip':
|
||||||
|
for q in embeddings.keys():
|
||||||
|
embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
|
||||||
|
embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
|
||||||
|
if objective == 'contrastive':
|
||||||
|
vocab_embeddings = torch.load(vocab_features)
|
||||||
|
vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
self.objective = objective
|
||||||
|
self.vocab_len = len(vocab)
|
||||||
|
|
||||||
|
self.embeddings = []
|
||||||
|
self.answers = []
|
||||||
|
|
||||||
|
for o in aokvqa_set:
|
||||||
|
correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
|
||||||
|
correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
|
||||||
|
if self.objective == 'contrastive':
|
||||||
|
correct_answers = [vocab_embeddings[a] for a in correct_answers]
|
||||||
|
if len(correct_answers) == 0: continue
|
||||||
|
self.answers.append(correct_answers)
|
||||||
|
|
||||||
|
q = o['question_id']
|
||||||
|
if 'question' in inputs and 'image' in inputs:
|
||||||
|
e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
|
||||||
|
elif 'question' in inputs and 'image' not in inputs:
|
||||||
|
e = embeddings[q]['question']
|
||||||
|
elif 'question' not in inputs and 'image' in inputs:
|
||||||
|
e = embeddings[q]['image']
|
||||||
|
self.embeddings.append(e)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
e = self.embeddings[index]
|
||||||
|
a = self.answers[index]
|
||||||
|
if self.objective == 'classifier':
|
||||||
|
a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
|
||||||
|
elif self.objective == 'contrastive':
|
||||||
|
a = random.sample(a, 1)[0]
|
||||||
|
return e, a
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
class AokvqaEmbeddingsDataModule(pl.LightningDataModule):
|
||||||
|
|
||||||
|
def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
|
||||||
|
super().__init__()
|
||||||
|
self.aokvqa_dir = aokvqa_dir
|
||||||
|
self.train_features = train_features
|
||||||
|
self.val_features = val_features
|
||||||
|
self.objective = objective
|
||||||
|
self.backbone = backbone
|
||||||
|
self.inputs = inputs
|
||||||
|
self.vocab = vocab
|
||||||
|
self.vocab_features = vocab_features
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_workers = num_workers
|
||||||
|
|
||||||
|
def setup(self, stage=None):
|
||||||
|
self.train_dataset = AokvqaEmbeddingsDataset(
|
||||||
|
self.aokvqa_dir, 'train', self.train_features, self.objective,
|
||||||
|
self.backbone, self.inputs, self.vocab, self.vocab_features
|
||||||
|
)
|
||||||
|
self.val_dataset = AokvqaEmbeddingsDataset(
|
||||||
|
self.aokvqa_dir, 'val', self.val_features, self.objective,
|
||||||
|
self.backbone, self.inputs, self.vocab, self.vocab_features
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
return DataLoader(
|
||||||
|
self.train_dataset, batch_size=self.batch_size, shuffle=True,
|
||||||
|
num_workers=int(0.8 * self.num_workers)
|
||||||
|
)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
return DataLoader(
|
||||||
|
self.val_dataset, batch_size=self.batch_size, shuffle=False,
|
||||||
|
num_workers=int(0.2 * self.num_workers)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearClassifier(pl.LightningModule):
|
||||||
|
def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
|
||||||
|
super().__init__()
|
||||||
|
self.save_hyperparameters(ignore=['lr'])
|
||||||
|
self.lr = lr
|
||||||
|
|
||||||
|
if self.hparams.backbone == 'clip':
|
||||||
|
clip_dim = {
|
||||||
|
'RN50' : 1024,
|
||||||
|
'RN50x4' : 640,
|
||||||
|
'RN50x16' : 768,
|
||||||
|
'RN50x64' : 1024,
|
||||||
|
'RN101' : 512,
|
||||||
|
'ViT-B/32' : 512,
|
||||||
|
'ViT-B/16' : 512,
|
||||||
|
'ViT-L/14' : 768,
|
||||||
|
'ViT-L/14@336px' : 768,
|
||||||
|
}[clip_model_type]
|
||||||
|
emb_dim = clip_dim * len(inputs)
|
||||||
|
elif self.hparams.backbone == 'resnet':
|
||||||
|
emb_dim = 2048
|
||||||
|
elif self.hparams.backbone == 'bert':
|
||||||
|
emb_dim = 768
|
||||||
|
|
||||||
|
if self.hparams.objective == 'classifier':
|
||||||
|
out_dim = vocab_len
|
||||||
|
elif self.hparams.objective == 'contrastive':
|
||||||
|
out_dim = clip_dim
|
||||||
|
|
||||||
|
self.linear = nn.Linear(emb_dim, out_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear(x)
|
||||||
|
if self.hparams.objective == 'classifier':
|
||||||
|
x = torch.sigmoid(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
x, y = batch
|
||||||
|
|
||||||
|
y_pred = self.forward(x)
|
||||||
|
|
||||||
|
if self.hparams.objective == 'classifier':
|
||||||
|
loss = F.binary_cross_entropy(y_pred, y.float())
|
||||||
|
elif self.hparams.objective == 'contrastive':
|
||||||
|
indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
|
||||||
|
sim = (y_pred @ y.T).softmax(dim=-1)
|
||||||
|
loss = F.cross_entropy(sim, indices)
|
||||||
|
|
||||||
|
if self.hparams.objective == 'classifier':
|
||||||
|
acc = MF.f1_score(y_pred, y)
|
||||||
|
elif self.hparams.objective == 'contrastive':
|
||||||
|
acc = torch.mean(sim[indices, indices])
|
||||||
|
|
||||||
|
return loss, acc
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
loss, acc = self.compute_loss(batch)
|
||||||
|
self.log("train_loss", loss)
|
||||||
|
self.log("train_acc", acc)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
loss, acc = self.compute_loss(batch)
|
||||||
|
self.log("val_loss", loss)
|
||||||
|
self.log("val_acc", acc)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
211
models/common/vqa_tools/vqa.py
Normal file
211
models/common/vqa_tools/vqa.py
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
__author__ = "aagrawal"
|
||||||
|
__version__ = "0.9"
|
||||||
|
|
||||||
|
# Interface for accessing the VQA dataset.
|
||||||
|
|
||||||
|
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
||||||
|
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
||||||
|
|
||||||
|
# The following functions are defined:
|
||||||
|
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
||||||
|
# getQuesIds - Get question ids that satisfy given filter conditions.
|
||||||
|
# getImgIds - Get image ids that satisfy given filter conditions.
|
||||||
|
# loadQA - Load questions and answers with the specified question ids.
|
||||||
|
# showQA - Display the specified questions and answers.
|
||||||
|
# loadRes - Load result file and create result object.
|
||||||
|
|
||||||
|
# Help on each function can be accessed by: "help(COCO.function)"
|
||||||
|
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class VQA:
|
||||||
|
def __init__(self, annotation_file=None, question_file=None):
|
||||||
|
"""
|
||||||
|
Constructor of VQA helper class for reading and visualizing questions and answers.
|
||||||
|
:param annotation_file (str): location of VQA annotation file
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# load dataset
|
||||||
|
self.dataset = {}
|
||||||
|
self.questions = {}
|
||||||
|
self.qa = {}
|
||||||
|
self.qqa = {}
|
||||||
|
self.imgToQA = {}
|
||||||
|
if not annotation_file == None and not question_file == None:
|
||||||
|
print("loading VQA annotations and questions into memory...")
|
||||||
|
time_t = datetime.datetime.utcnow()
|
||||||
|
dataset = json.load(open(annotation_file, "r"))
|
||||||
|
questions = json.load(open(question_file, "r"))
|
||||||
|
self.dataset = dataset
|
||||||
|
self.questions = questions
|
||||||
|
self.createIndex()
|
||||||
|
|
||||||
|
def createIndex(self):
|
||||||
|
# create index
|
||||||
|
print("creating index...")
|
||||||
|
imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
|
||||||
|
qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
||||||
|
qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
||||||
|
for ann in self.dataset["annotations"]:
|
||||||
|
imgToQA[ann["image_id"]] += [ann]
|
||||||
|
qa[ann["question_id"]] = ann
|
||||||
|
for ques in self.questions["questions"]:
|
||||||
|
qqa[ques["question_id"]] = ques
|
||||||
|
print("index created!")
|
||||||
|
|
||||||
|
# create class members
|
||||||
|
self.qa = qa
|
||||||
|
self.qqa = qqa
|
||||||
|
self.imgToQA = imgToQA
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
"""
|
||||||
|
Print information about the VQA annotation file.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for key, value in self.datset["info"].items():
|
||||||
|
print("%s: %s" % (key, value))
|
||||||
|
|
||||||
|
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
|
||||||
|
"""
|
||||||
|
Get question ids that satisfy given filter conditions. default skips that filter
|
||||||
|
:param imgIds (int array) : get question ids for given imgs
|
||||||
|
quesTypes (str array) : get question ids for given question types
|
||||||
|
ansTypes (str array) : get question ids for given answer types
|
||||||
|
:return: ids (int array) : integer array of question ids
|
||||||
|
"""
|
||||||
|
imgIds = imgIds if type(imgIds) == list else [imgIds]
|
||||||
|
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
||||||
|
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
||||||
|
|
||||||
|
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
|
||||||
|
anns = self.dataset["annotations"]
|
||||||
|
else:
|
||||||
|
if not len(imgIds) == 0:
|
||||||
|
anns = sum(
|
||||||
|
[self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
anns = self.dataset["annotations"]
|
||||||
|
anns = (
|
||||||
|
anns
|
||||||
|
if len(quesTypes) == 0
|
||||||
|
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
||||||
|
)
|
||||||
|
anns = (
|
||||||
|
anns
|
||||||
|
if len(ansTypes) == 0
|
||||||
|
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
||||||
|
)
|
||||||
|
ids = [ann["question_id"] for ann in anns]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
|
||||||
|
"""
|
||||||
|
Get image ids that satisfy given filter conditions. default skips that filter
|
||||||
|
:param quesIds (int array) : get image ids for given question ids
|
||||||
|
quesTypes (str array) : get image ids for given question types
|
||||||
|
ansTypes (str array) : get image ids for given answer types
|
||||||
|
:return: ids (int array) : integer array of image ids
|
||||||
|
"""
|
||||||
|
quesIds = quesIds if type(quesIds) == list else [quesIds]
|
||||||
|
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
||||||
|
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
||||||
|
|
||||||
|
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
|
||||||
|
anns = self.dataset["annotations"]
|
||||||
|
else:
|
||||||
|
if not len(quesIds) == 0:
|
||||||
|
anns = sum(
|
||||||
|
[self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
anns = self.dataset["annotations"]
|
||||||
|
anns = (
|
||||||
|
anns
|
||||||
|
if len(quesTypes) == 0
|
||||||
|
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
||||||
|
)
|
||||||
|
anns = (
|
||||||
|
anns
|
||||||
|
if len(ansTypes) == 0
|
||||||
|
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
||||||
|
)
|
||||||
|
ids = [ann["image_id"] for ann in anns]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def loadQA(self, ids=[]):
|
||||||
|
"""
|
||||||
|
Load questions and answers with the specified question ids.
|
||||||
|
:param ids (int array) : integer ids specifying question ids
|
||||||
|
:return: qa (object array) : loaded qa objects
|
||||||
|
"""
|
||||||
|
if type(ids) == list:
|
||||||
|
return [self.qa[id] for id in ids]
|
||||||
|
elif type(ids) == int:
|
||||||
|
return [self.qa[ids]]
|
||||||
|
|
||||||
|
def showQA(self, anns):
|
||||||
|
"""
|
||||||
|
Display the specified annotations.
|
||||||
|
:param anns (array of object): annotations to display
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if len(anns) == 0:
|
||||||
|
return 0
|
||||||
|
for ann in anns:
|
||||||
|
quesId = ann["question_id"]
|
||||||
|
print("Question: %s" % (self.qqa[quesId]["question"]))
|
||||||
|
for ans in ann["answers"]:
|
||||||
|
print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
|
||||||
|
|
||||||
|
def loadRes(self, resFile, quesFile):
|
||||||
|
"""
|
||||||
|
Load result file and return a result object.
|
||||||
|
:param resFile (str) : file name of result file
|
||||||
|
:return: res (obj) : result api object
|
||||||
|
"""
|
||||||
|
res = VQA()
|
||||||
|
res.questions = json.load(open(quesFile))
|
||||||
|
res.dataset["info"] = copy.deepcopy(self.questions["info"])
|
||||||
|
res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
|
||||||
|
res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
|
||||||
|
res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
|
||||||
|
res.dataset["license"] = copy.deepcopy(self.questions["license"])
|
||||||
|
|
||||||
|
print("Loading and preparing results... ")
|
||||||
|
time_t = datetime.datetime.utcnow()
|
||||||
|
anns = json.load(open(resFile))
|
||||||
|
assert type(anns) == list, "results is not an array of objects"
|
||||||
|
annsQuesIds = [ann["question_id"] for ann in anns]
|
||||||
|
assert set(annsQuesIds) == set(
|
||||||
|
self.getQuesIds()
|
||||||
|
), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
|
||||||
|
for ann in anns:
|
||||||
|
quesId = ann["question_id"]
|
||||||
|
if res.dataset["task_type"] == "Multiple Choice":
|
||||||
|
assert (
|
||||||
|
ann["answer"] in self.qqa[quesId]["multiple_choices"]
|
||||||
|
), "predicted answer is not one of the multiple choices"
|
||||||
|
qaAnn = self.qa[quesId]
|
||||||
|
ann["image_id"] = qaAnn["image_id"]
|
||||||
|
ann["question_type"] = qaAnn["question_type"]
|
||||||
|
ann["answer_type"] = qaAnn["answer_type"]
|
||||||
|
print(
|
||||||
|
"DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
|
||||||
|
)
|
||||||
|
|
||||||
|
res.dataset["annotations"] = anns
|
||||||
|
res.createIndex()
|
||||||
|
return res
|
324
models/common/vqa_tools/vqa_eval.py
Normal file
324
models/common/vqa_tools/vqa_eval.py
Normal file
|
@ -0,0 +1,324 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
All rights reserved.
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
"""
|
||||||
|
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
__author__ = "aagrawal"
|
||||||
|
|
||||||
|
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
||||||
|
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class VQAEval:
|
||||||
|
def __init__(self, vqa=None, vqaRes=None, n=2):
|
||||||
|
self.n = n
|
||||||
|
self.accuracy = {}
|
||||||
|
self.evalQA = {}
|
||||||
|
self.evalQuesType = {}
|
||||||
|
self.evalAnsType = {}
|
||||||
|
self.vqa = vqa
|
||||||
|
self.vqaRes = vqaRes
|
||||||
|
if vqa is not None:
|
||||||
|
self.params = {"question_id": vqa.getQuesIds()}
|
||||||
|
self.contractions = {
|
||||||
|
"aint": "ain't",
|
||||||
|
"arent": "aren't",
|
||||||
|
"cant": "can't",
|
||||||
|
"couldve": "could've",
|
||||||
|
"couldnt": "couldn't",
|
||||||
|
"couldn'tve": "couldn't've",
|
||||||
|
"couldnt've": "couldn't've",
|
||||||
|
"didnt": "didn't",
|
||||||
|
"doesnt": "doesn't",
|
||||||
|
"dont": "don't",
|
||||||
|
"hadnt": "hadn't",
|
||||||
|
"hadnt've": "hadn't've",
|
||||||
|
"hadn'tve": "hadn't've",
|
||||||
|
"hasnt": "hasn't",
|
||||||
|
"havent": "haven't",
|
||||||
|
"hed": "he'd",
|
||||||
|
"hed've": "he'd've",
|
||||||
|
"he'dve": "he'd've",
|
||||||
|
"hes": "he's",
|
||||||
|
"howd": "how'd",
|
||||||
|
"howll": "how'll",
|
||||||
|
"hows": "how's",
|
||||||
|
"Id've": "I'd've",
|
||||||
|
"I'dve": "I'd've",
|
||||||
|
"Im": "I'm",
|
||||||
|
"Ive": "I've",
|
||||||
|
"isnt": "isn't",
|
||||||
|
"itd": "it'd",
|
||||||
|
"itd've": "it'd've",
|
||||||
|
"it'dve": "it'd've",
|
||||||
|
"itll": "it'll",
|
||||||
|
"let's": "let's",
|
||||||
|
"maam": "ma'am",
|
||||||
|
"mightnt": "mightn't",
|
||||||
|
"mightnt've": "mightn't've",
|
||||||
|
"mightn'tve": "mightn't've",
|
||||||
|
"mightve": "might've",
|
||||||
|
"mustnt": "mustn't",
|
||||||
|
"mustve": "must've",
|
||||||
|
"neednt": "needn't",
|
||||||
|
"notve": "not've",
|
||||||
|
"oclock": "o'clock",
|
||||||
|
"oughtnt": "oughtn't",
|
||||||
|
"ow's'at": "'ow's'at",
|
||||||
|
"'ows'at": "'ow's'at",
|
||||||
|
"'ow'sat": "'ow's'at",
|
||||||
|
"shant": "shan't",
|
||||||
|
"shed've": "she'd've",
|
||||||
|
"she'dve": "she'd've",
|
||||||
|
"she's": "she's",
|
||||||
|
"shouldve": "should've",
|
||||||
|
"shouldnt": "shouldn't",
|
||||||
|
"shouldnt've": "shouldn't've",
|
||||||
|
"shouldn'tve": "shouldn't've",
|
||||||
|
"somebody'd": "somebodyd",
|
||||||
|
"somebodyd've": "somebody'd've",
|
||||||
|
"somebody'dve": "somebody'd've",
|
||||||
|
"somebodyll": "somebody'll",
|
||||||
|
"somebodys": "somebody's",
|
||||||
|
"someoned": "someone'd",
|
||||||
|
"someoned've": "someone'd've",
|
||||||
|
"someone'dve": "someone'd've",
|
||||||
|
"someonell": "someone'll",
|
||||||
|
"someones": "someone's",
|
||||||
|
"somethingd": "something'd",
|
||||||
|
"somethingd've": "something'd've",
|
||||||
|
"something'dve": "something'd've",
|
||||||
|
"somethingll": "something'll",
|
||||||
|
"thats": "that's",
|
||||||
|
"thered": "there'd",
|
||||||
|
"thered've": "there'd've",
|
||||||
|
"there'dve": "there'd've",
|
||||||
|
"therere": "there're",
|
||||||
|
"theres": "there's",
|
||||||
|
"theyd": "they'd",
|
||||||
|
"theyd've": "they'd've",
|
||||||
|
"they'dve": "they'd've",
|
||||||
|
"theyll": "they'll",
|
||||||
|
"theyre": "they're",
|
||||||
|
"theyve": "they've",
|
||||||
|
"twas": "'twas",
|
||||||
|
"wasnt": "wasn't",
|
||||||
|
"wed've": "we'd've",
|
||||||
|
"we'dve": "we'd've",
|
||||||
|
"weve": "we've",
|
||||||
|
"werent": "weren't",
|
||||||
|
"whatll": "what'll",
|
||||||
|
"whatre": "what're",
|
||||||
|
"whats": "what's",
|
||||||
|
"whatve": "what've",
|
||||||
|
"whens": "when's",
|
||||||
|
"whered": "where'd",
|
||||||
|
"wheres": "where's",
|
||||||
|
"whereve": "where've",
|
||||||
|
"whod": "who'd",
|
||||||
|
"whod've": "who'd've",
|
||||||
|
"who'dve": "who'd've",
|
||||||
|
"wholl": "who'll",
|
||||||
|
"whos": "who's",
|
||||||
|
"whove": "who've",
|
||||||
|
"whyll": "why'll",
|
||||||
|
"whyre": "why're",
|
||||||
|
"whys": "why's",
|
||||||
|
"wont": "won't",
|
||||||
|
"wouldve": "would've",
|
||||||
|
"wouldnt": "wouldn't",
|
||||||
|
"wouldnt've": "wouldn't've",
|
||||||
|
"wouldn'tve": "wouldn't've",
|
||||||
|
"yall": "y'all",
|
||||||
|
"yall'll": "y'all'll",
|
||||||
|
"y'allll": "y'all'll",
|
||||||
|
"yall'd've": "y'all'd've",
|
||||||
|
"y'alld've": "y'all'd've",
|
||||||
|
"y'all'dve": "y'all'd've",
|
||||||
|
"youd": "you'd",
|
||||||
|
"youd've": "you'd've",
|
||||||
|
"you'dve": "you'd've",
|
||||||
|
"youll": "you'll",
|
||||||
|
"youre": "you're",
|
||||||
|
"youve": "you've",
|
||||||
|
}
|
||||||
|
self.manualMap = {
|
||||||
|
"none": "0",
|
||||||
|
"zero": "0",
|
||||||
|
"one": "1",
|
||||||
|
"two": "2",
|
||||||
|
"three": "3",
|
||||||
|
"four": "4",
|
||||||
|
"five": "5",
|
||||||
|
"six": "6",
|
||||||
|
"seven": "7",
|
||||||
|
"eight": "8",
|
||||||
|
"nine": "9",
|
||||||
|
"ten": "10",
|
||||||
|
}
|
||||||
|
self.articles = ["a", "an", "the"]
|
||||||
|
|
||||||
|
self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")
|
||||||
|
self.commaStrip = re.compile("(\d)(,)(\d)")
|
||||||
|
self.punct = [
|
||||||
|
";",
|
||||||
|
r"/",
|
||||||
|
"[",
|
||||||
|
"]",
|
||||||
|
'"',
|
||||||
|
"{",
|
||||||
|
"}",
|
||||||
|
"(",
|
||||||
|
")",
|
||||||
|
"=",
|
||||||
|
"+",
|
||||||
|
"\\",
|
||||||
|
"_",
|
||||||
|
"-",
|
||||||
|
">",
|
||||||
|
"<",
|
||||||
|
"@",
|
||||||
|
"`",
|
||||||
|
",",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
]
|
||||||
|
|
||||||
|
def evaluate(self, quesIds=None):
|
||||||
|
if quesIds is None:
|
||||||
|
quesIds = [quesId for quesId in self.params["question_id"]]
|
||||||
|
gts = {}
|
||||||
|
res = {}
|
||||||
|
for quesId in quesIds:
|
||||||
|
gts[quesId] = self.vqa.qa[quesId]
|
||||||
|
res[quesId] = self.vqaRes.qa[quesId]
|
||||||
|
|
||||||
|
# =================================================
|
||||||
|
# Compute accuracy
|
||||||
|
# =================================================
|
||||||
|
accQA = []
|
||||||
|
accQuesType = {}
|
||||||
|
accAnsType = {}
|
||||||
|
print("computing accuracy")
|
||||||
|
step = 0
|
||||||
|
for quesId in quesIds:
|
||||||
|
resAns = res[quesId]["answer"]
|
||||||
|
resAns = resAns.replace("\n", " ")
|
||||||
|
resAns = resAns.replace("\t", " ")
|
||||||
|
resAns = resAns.strip()
|
||||||
|
resAns = self.processPunctuation(resAns)
|
||||||
|
resAns = self.processDigitArticle(resAns)
|
||||||
|
gtAcc = []
|
||||||
|
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
|
||||||
|
if len(set(gtAnswers)) > 1:
|
||||||
|
for ansDic in gts[quesId]["answers"]:
|
||||||
|
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
|
||||||
|
for gtAnsDatum in gts[quesId]["answers"]:
|
||||||
|
otherGTAns = [
|
||||||
|
item for item in gts[quesId]["answers"] if item != gtAnsDatum
|
||||||
|
]
|
||||||
|
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
|
||||||
|
acc = min(1, float(len(matchingAns)) / 3)
|
||||||
|
gtAcc.append(acc)
|
||||||
|
quesType = gts[quesId]["question_type"]
|
||||||
|
ansType = gts[quesId]["answer_type"]
|
||||||
|
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
||||||
|
accQA.append(avgGTAcc)
|
||||||
|
if quesType not in accQuesType:
|
||||||
|
accQuesType[quesType] = []
|
||||||
|
accQuesType[quesType].append(avgGTAcc)
|
||||||
|
if ansType not in accAnsType:
|
||||||
|
accAnsType[ansType] = []
|
||||||
|
accAnsType[ansType].append(avgGTAcc)
|
||||||
|
self.setEvalQA(quesId, avgGTAcc)
|
||||||
|
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
||||||
|
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
||||||
|
if step % 100 == 0:
|
||||||
|
self.updateProgress(step / float(len(quesIds)))
|
||||||
|
step = step + 1
|
||||||
|
|
||||||
|
self.setAccuracy(accQA, accQuesType, accAnsType)
|
||||||
|
print("Done computing accuracy")
|
||||||
|
|
||||||
|
def processPunctuation(self, inText):
|
||||||
|
outText = inText
|
||||||
|
for p in self.punct:
|
||||||
|
if (p + " " in inText or " " + p in inText) or (
|
||||||
|
re.search(self.commaStrip, inText) != None
|
||||||
|
):
|
||||||
|
outText = outText.replace(p, "")
|
||||||
|
else:
|
||||||
|
outText = outText.replace(p, " ")
|
||||||
|
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
||||||
|
return outText
|
||||||
|
|
||||||
|
def processDigitArticle(self, inText):
|
||||||
|
outText = []
|
||||||
|
tempText = inText.lower().split()
|
||||||
|
for word in tempText:
|
||||||
|
word = self.manualMap.setdefault(word, word)
|
||||||
|
if word not in self.articles:
|
||||||
|
outText.append(word)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
for wordId, word in enumerate(outText):
|
||||||
|
if word in self.contractions:
|
||||||
|
outText[wordId] = self.contractions[word]
|
||||||
|
outText = " ".join(outText)
|
||||||
|
return outText
|
||||||
|
|
||||||
|
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
||||||
|
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
|
||||||
|
self.accuracy["perQuestionType"] = {
|
||||||
|
quesType: round(
|
||||||
|
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
|
||||||
|
self.n,
|
||||||
|
)
|
||||||
|
for quesType in accQuesType
|
||||||
|
}
|
||||||
|
self.accuracy["perAnswerType"] = {
|
||||||
|
ansType: round(
|
||||||
|
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
|
||||||
|
)
|
||||||
|
for ansType in accAnsType
|
||||||
|
}
|
||||||
|
|
||||||
|
def setEvalQA(self, quesId, acc):
|
||||||
|
self.evalQA[quesId] = round(100 * acc, self.n)
|
||||||
|
|
||||||
|
def setEvalQuesType(self, quesId, quesType, acc):
|
||||||
|
if quesType not in self.evalQuesType:
|
||||||
|
self.evalQuesType[quesType] = {}
|
||||||
|
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
||||||
|
|
||||||
|
def setEvalAnsType(self, quesId, ansType, acc):
|
||||||
|
if ansType not in self.evalAnsType:
|
||||||
|
self.evalAnsType[ansType] = {}
|
||||||
|
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
||||||
|
|
||||||
|
def updateProgress(self, progress):
|
||||||
|
barLength = 20
|
||||||
|
status = ""
|
||||||
|
if isinstance(progress, int):
|
||||||
|
progress = float(progress)
|
||||||
|
if not isinstance(progress, float):
|
||||||
|
progress = 0
|
||||||
|
status = "error: progress var must be float\r\n"
|
||||||
|
if progress < 0:
|
||||||
|
progress = 0
|
||||||
|
status = "Halt...\r\n"
|
||||||
|
if progress >= 1:
|
||||||
|
progress = 1
|
||||||
|
status = "Done...\r\n"
|
||||||
|
block = int(round(barLength * progress))
|
||||||
|
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
|
||||||
|
"#" * block + "-" * (barLength - block), int(progress * 100), status
|
||||||
|
)
|
||||||
|
sys.stdout.write(text)
|
||||||
|
sys.stdout.flush()
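

# Illustrative sketches of the scoring above (toy inputs only, not part of the
# evaluation API). 1) The consensus accuracy computed in `evaluate`: each
# ground-truth answer is left out in turn and the prediction scores
# min(1, #matching remaining answers / 3), averaged over the left-out splits.
def _consensus_accuracy_demo(pred, gt_answers):
    scores = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]
        scores.append(min(1.0, sum(a == pred for a in others) / 3.0))
    return sum(scores) / len(scores)  # e.g. ("2", ["2"] * 7 + ["two"] * 3) -> 1.0


# 2) The answer normalization applied before matching: punctuation stripped,
# articles dropped, number words mapped to digits.
def _answer_normalization_demo():
    ev = VQAEval()
    ans = ev.processDigitArticle(ev.processPunctuation("A dog, two cats."))
    assert ans == "dog 2 cats"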
|
654
models/criteria.py
Normal file
@@ -0,0 +1,654 @@
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from models.utils import allgather_wgrad
|
||||||
|
from utils.dist import get_rank, get_world_size
|
||||||
|
from utils.easydict import EasyDict
|
||||||
|
|
||||||
|
|
||||||
|
def get_sim(
|
||||||
|
x_proj: torch.Tensor,
|
||||||
|
y_proj: torch.Tensor,
|
||||||
|
temp=1.0,
|
||||||
|
):
|
||||||
|
"""calculate pair-wise similarity between two modalities x and y.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_proj (torch.Tensor): The representation of modality x. Shape: [B,T,C] or [B,C].
|
||||||
|
y_proj (torch.Tensor): The representation of modality y. Shape: [B,C].
|
||||||
|
temp (torch.Tensor): The temperature. Shape: [].
|
||||||
|
|
||||||
|
Returns: The similarity between modality x and y. Shape: [B,B].
|
||||||
|
|
||||||
|
"""
|
||||||
|
x_proj = F.normalize(x_proj, dim=-1)
|
||||||
|
y_proj = F.normalize(y_proj, dim=-1)
|
||||||
|
assert x_proj.dim() in [2, 3]
|
||||||
|
assert y_proj.dim() == 2
|
||||||
|
if x_proj.dim() == 2:
|
||||||
|
sim_x2y = torch.einsum("md,nd->mn", x_proj, y_proj) / temp # (B,B)
|
||||||
|
else:
|
||||||
|
sim_x2y = torch.einsum("mld,nd->mln", x_proj, y_proj).mean(1) / temp # (B,B)
|
||||||
|
sim_y2x = sim_x2y.T
|
||||||
|
return sim_x2y, sim_y2x
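

# Minimal shape sketch for get_sim (illustrative, random tensors): x_proj may be
# frame-level [B, T, C] or pooled [B, C]; y_proj is pooled [B, C]; both returned
# similarity matrices are [B, B] and transposes of each other.
def _get_sim_shape_demo():
    x = torch.randn(4, 8, 256)  # B=4, T=8, C=256
    y = torch.randn(4, 256)
    sim_xy, sim_yx = get_sim(x, y, temp=0.07)
    assert sim_xy.shape == (4, 4)
    assert torch.allclose(sim_yx, sim_xy.t())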
|
||||||
|
|
||||||
|
|
||||||
|
class ContMatchLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(ContMatchLoss, self).__init__()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_mask(self, sim, idx=None, normalize=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
sim (torch.Tensor): The similarity between videos and texts. shape: (B, B).
|
||||||
|
idx (torch.Tensor): The index for each video. Shape: [B].
|
||||||
|
normalize (bool): If true, make row sum equal to 1
|
||||||
|
"""
|
||||||
|
if idx is not None:
|
||||||
|
idx = idx.view(-1, 1)
|
||||||
|
mask = torch.eq(idx, idx.T).to(sim.dtype)
|
||||||
|
if normalize:
|
||||||
|
mask = mask / mask.sum(1, keepdim=True)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros_like(sim)
|
||||||
|
mask.fill_diagonal_(1)
|
||||||
|
return mask # `1` mark valid/matched location
|
||||||
|
|
||||||
|
@lru_cache(maxsize=16)
|
||||||
|
def get_gather_args(self):
|
||||||
|
"""obtain the args for all_gather
|
||||||
|
Returns: dict.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return EasyDict({"world_size": get_world_size(), "rank": get_rank()})
|
||||||
|
|
||||||
|
|
||||||
|
class STC_STM_Loss(ContMatchLoss):
|
||||||
|
"""Contrastive and matching losses"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(STC_STM_Loss, self).__init__()
|
||||||
|
|
||||||
|
def stc_loss(
|
||||||
|
self,
|
||||||
|
temporal_proj: torch.Tensor,
|
||||||
|
spatial_proj: torch.Tensor,
|
||||||
|
idx: torch.Tensor,
|
||||||
|
temp=1.0,
|
||||||
|
all_gather=True
|
||||||
|
):
|
||||||
|
|
||||||
|
"""forward to calculate the loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temporal_proj (torch.Tensor): The temporal expert representation. Shape: [B,T,C] or [B,C].
spatial_proj (torch.Tensor): The spatial expert representation. Shape: [B,C].
|
||||||
|
idx (torch.Tensor): The index for each example. Shape: [B,].
|
||||||
|
temp (torch.Tensor): The temperature. Shape: [].
|
||||||
|
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
|
||||||
|
|
||||||
|
Returns: loss_stc (torch.Tensor): The spatial-temporal contrastive loss. Shape: [].
|
||||||
|
|
||||||
|
"""
|
||||||
|
if all_gather:
|
||||||
|
gather_args = self.get_gather_args()
|
||||||
|
temporal_proj = allgather_wgrad(temporal_proj, gather_args)
|
||||||
|
spatial_proj = allgather_wgrad(spatial_proj, gather_args)
|
||||||
|
if idx is not None:
|
||||||
|
idx = allgather_wgrad(idx, gather_args)
|
||||||
|
|
||||||
|
sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_t2s_targets = self.get_mask(sim_t2s, idx=idx, normalize=True)
|
||||||
|
sim_s2t_targets = sim_t2s_targets
|
||||||
|
|
||||||
|
loss_t2s = -torch.sum(F.log_softmax(sim_t2s, dim=1) * sim_t2s_targets, dim=1).mean()
|
||||||
|
loss_s2t = -torch.sum(F.log_softmax(sim_s2t, dim=1) * sim_s2t_targets, dim=1).mean()
|
||||||
|
|
||||||
|
loss_stc = (loss_t2s + loss_s2t) / 2
|
||||||
|
return loss_stc
|
||||||
|
|
||||||
|
def stm_loss(
|
||||||
|
self,
|
||||||
|
grounding_expert,
|
||||||
|
stm_head,
|
||||||
|
# temp,
|
||||||
|
spatial_embeds_orig,
|
||||||
|
temporal_embeds_orig,
|
||||||
|
temporal_proj,
|
||||||
|
spatial_proj,
|
||||||
|
idx,
|
||||||
|
generation=False,
|
||||||
|
temp=1.0
|
||||||
|
):
|
||||||
|
spatial_embeds = spatial_embeds_orig.clone()
|
||||||
|
temporal_embeds = temporal_embeds_orig.clone()
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_s2t, sim_t2s = get_sim(temporal_proj, spatial_proj, temp)
|
||||||
|
spatial_atts = torch.ones(
|
||||||
|
spatial_embeds.size()[:-1], dtype=torch.long, device=spatial_embeds.device
|
||||||
|
)
|
||||||
|
temporal_atts = torch.ones(
|
||||||
|
temporal_embeds.size()[:-1], dtype=torch.long, device=temporal_embeds.device
|
||||||
|
)
|
||||||
|
weights_s2t = F.softmax(sim_s2t + 1e-4, dim=1) # (N, N)
|
||||||
|
weights_t2s = F.softmax(sim_t2s + 1e-4, dim=1)
|
||||||
|
|
||||||
|
mask = self.get_mask(sim_s2t, idx=idx).bool()
|
||||||
|
weights_s2t.masked_fill_(mask, 0)
|
||||||
|
weights_t2s.masked_fill_(mask, 0)
|
||||||
|
weights_s2t = torch.nan_to_num_(weights_s2t, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
weights_t2s = torch.nan_to_num_(weights_t2s, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = grounding_expert(
|
||||||
|
encoder_embeds=temporal_embeds,
|
||||||
|
attention_mask=temporal_atts,
|
||||||
|
encoder_hidden_states=spatial_embeds,
|
||||||
|
encoder_attention_mask=spatial_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
pos_feats = output.last_hidden_state
|
||||||
|
return pos_feats
|
||||||
|
|
||||||
|
else:
|
||||||
|
# select hard negatives within the batch
|
||||||
|
spatial_neg_indices = torch.multinomial(weights_s2t, 1).squeeze()
|
||||||
|
temporal_neg_indices = torch.multinomial(weights_t2s, 1).squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
spatial_embeds_neg = spatial_embeds[spatial_neg_indices] # [B, L, c]
|
||||||
|
temporal_embeds_neg = temporal_embeds[temporal_neg_indices] # [B, L, d]
|
||||||
|
# temporal_atts_neg = temporal_atts[temporal_neg_indices]
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
spatial_embeds_all = torch.cat([spatial_embeds, spatial_embeds_neg, spatial_embeds], dim=0)
|
||||||
|
temporal_embeds_all = torch.cat([temporal_embeds, temporal_embeds, temporal_embeds_neg], dim=0)
|
||||||
|
spatial_atts_all = torch.cat([spatial_atts, spatial_atts, spatial_atts], dim=0)
|
||||||
|
temporal_atts_all = torch.cat([temporal_atts, temporal_atts, temporal_atts], dim=0)
|
||||||
|
|
||||||
|
output = grounding_expert(
|
||||||
|
inputs_embeds=temporal_embeds_all,
|
||||||
|
attention_mask=temporal_atts_all,
|
||||||
|
cross_embeds=spatial_embeds_all,
|
||||||
|
cross_attention_mask=spatial_atts_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
stm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
|
||||||
|
|
||||||
|
stm_logits = stm_head(stm_embeds) # [3*B, 2]
|
||||||
|
|
||||||
|
bs = stm_logits.shape[0] // 3
|
||||||
|
stm_labels = stm_logits.new_ones(3 * bs, dtype=torch.long)
|
||||||
|
stm_labels[bs:] = 0
|
||||||
|
loss_stm = F.cross_entropy(stm_logits, stm_labels)
|
||||||
|
pos_feats = output.last_hidden_state[:bs]
|
||||||
|
|
||||||
|
return loss_stm, pos_feats
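

# Illustrative sketches of the two mechanisms above (toy tensors, not used in
# training). With one-hot targets the soft-target cross entropy of stc_loss
# reduces to standard InfoNCE, and stm_loss draws in-batch hard negatives from
# the similarity-derived weights with the positive pairs masked out.
def _soft_target_ce_demo():
    sim = torch.randn(4, 4)
    targets = torch.eye(4)
    loss = -torch.sum(F.log_softmax(sim, dim=1) * targets, dim=1).mean()
    assert torch.allclose(loss, F.cross_entropy(sim, torch.arange(4)))


def _hard_negative_sampling_demo():
    weights = F.softmax(torch.randn(4, 4), dim=1)
    weights.fill_diagonal_(0)  # a sample must never draw itself as its negative
    neg_idx = torch.multinomial(weights, 1).squeeze()
    assert (neg_idx != torch.arange(4)).all()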
|
||||||
|
|
||||||
|
|
||||||
|
class VCC_VCM_Loss(ContMatchLoss):
|
||||||
|
"""Contrastive and matching losses"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(VCC_VCM_Loss, self).__init__()
|
||||||
|
|
||||||
|
def vcc_loss(
|
||||||
|
self,
|
||||||
|
vis_proj: torch.Tensor,
|
||||||
|
cap_proj: torch.Tensor,
|
||||||
|
idx: torch.Tensor,
|
||||||
|
temp=1.0,
|
||||||
|
all_gather=True
|
||||||
|
):
|
||||||
|
|
||||||
|
"""forward to calculate the loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vis_proj (torch.Tensor): The vision representation. Shape: [B,T,C] or [B,C].
cap_proj (torch.Tensor): The caption representation. Shape: [B,C].
|
||||||
|
idx (torch.Tensor): The index for each example. Shape: [B,].
|
||||||
|
temp (torch.Tensor): The temperature. Shape: [].
|
||||||
|
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
|
||||||
|
|
||||||
|
Returns: loss_vcc (torch.Tensor): The vision-caption contrastive loss. Shape: [].
|
||||||
|
|
||||||
|
"""
|
||||||
|
if all_gather:
|
||||||
|
gather_args = self.get_gather_args()
|
||||||
|
vis_proj = allgather_wgrad(vis_proj, gather_args)
|
||||||
|
cap_proj = allgather_wgrad(cap_proj, gather_args)
|
||||||
|
if idx is not None:
|
||||||
|
idx = allgather_wgrad(idx, gather_args)
|
||||||
|
|
||||||
|
sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_v2c_targets = self.get_mask(sim_v2c, idx=idx, normalize=True)
|
||||||
|
sim_c2v_targets = sim_v2c_targets
|
||||||
|
|
||||||
|
loss_v2c = -torch.sum(F.log_softmax(sim_v2c, dim=1) * sim_v2c_targets, dim=1).mean()
|
||||||
|
loss_c2v = -torch.sum(F.log_softmax(sim_c2v, dim=1) * sim_c2v_targets, dim=1).mean()
|
||||||
|
|
||||||
|
loss_vcc = (loss_v2c + loss_c2v) / 2
|
||||||
|
return loss_vcc
|
||||||
|
|
||||||
|
def vcm_loss(
|
||||||
|
self,
|
||||||
|
grounding_expert,
|
||||||
|
vcm_head,
|
||||||
|
vis_embeds_orig,
|
||||||
|
cap_embeds_orig,
|
||||||
|
vis_proj,
|
||||||
|
cap_proj,
|
||||||
|
cap_atts,
|
||||||
|
idx,
|
||||||
|
generation=False,
|
||||||
|
temp=1.0
|
||||||
|
):
|
||||||
|
vis_embeds = vis_embeds_orig.clone()
|
||||||
|
cap_embeds = cap_embeds_orig.clone()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)
|
||||||
|
vis_atts = torch.ones(
|
||||||
|
vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
weights_v2c = F.softmax(sim_v2c + 1e-4, dim=1) # (N, N)
|
||||||
|
weights_c2v = F.softmax(sim_c2v + 1e-4, dim=1)
|
||||||
|
|
||||||
|
mask = self.get_mask(weights_v2c, idx=idx).bool()
|
||||||
|
weights_v2c.masked_fill_(mask, 0)
|
||||||
|
weights_c2v.masked_fill_(mask, 0)
|
||||||
|
weights_v2c = torch.nan_to_num_(weights_v2c, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
weights_c2v = torch.nan_to_num_(weights_c2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = grounding_expert(
|
||||||
|
encoder_embeds=cap_embeds,
|
||||||
|
attention_mask=cap_atts,
|
||||||
|
encoder_hidden_states=vis_embeds,
|
||||||
|
encoder_attention_mask=vis_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
pos_feats = output.last_hidden_state
|
||||||
|
return pos_feats
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
|
||||||
|
# select hard negatives within the batch
|
||||||
|
vis_neg_indices = torch.multinomial(weights_v2c, 1).squeeze()
|
||||||
|
cap_neg_indices = torch.multinomial(weights_c2v, 1).squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c]
|
||||||
|
cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, d]
|
||||||
|
cap_atts_neg = cap_atts[cap_neg_indices]
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
|
||||||
|
cap_embeds_all = torch.cat([cap_embeds, cap_embeds, cap_embeds_neg], dim=0)
|
||||||
|
vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
|
||||||
|
cap_atts_all = torch.cat([cap_atts, cap_atts, cap_atts_neg], dim=0)
|
||||||
|
|
||||||
|
output = grounding_expert(
|
||||||
|
inputs_embeds=cap_embeds_all,
|
||||||
|
attention_mask=cap_atts_all,
|
||||||
|
cross_embeds=vis_embeds_all,
|
||||||
|
cross_attention_mask=vis_atts_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
vcm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
|
||||||
|
|
||||||
|
vcm_logits = vcm_head(vcm_embeds) # [3*B, 2]
|
||||||
|
|
||||||
|
bs = vcm_logits.shape[0] // 3
|
||||||
|
vcm_labels = vcm_logits.new_ones(3 * bs, dtype=torch.long)
|
||||||
|
vcm_labels[bs:] = 0
|
||||||
|
loss_vcm = F.cross_entropy(vcm_logits, vcm_labels)
|
||||||
|
pos_feats = output.last_hidden_state[:bs]
|
||||||
|
return loss_vcm, pos_feats
|
||||||
|
|
||||||
|
|
||||||
|
class VHC_VHM_Loss(ContMatchLoss):
|
||||||
|
"""Contrastive and matching losses"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(VHC_VHM_Loss, self).__init__()
|
||||||
|
|
||||||
|
def vhc_loss(
|
||||||
|
self,
|
||||||
|
vis_proj: torch.Tensor,
|
||||||
|
hist_proj: torch.Tensor,
|
||||||
|
idx: torch.Tensor,
|
||||||
|
temp=1.0,
|
||||||
|
all_gather=True
|
||||||
|
):
|
||||||
|
|
||||||
|
"""forward to calculate the loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vis_proj (torch.Tensor): The vision representation. Shape: [B,T,C] or [B,C].
hist_proj (torch.Tensor): The dialog history representation. Shape: [B,C].
|
||||||
|
idx (torch.Tensor): The index for each example. Shape: [B,].
|
||||||
|
temp (torch.Tensor): The temperature. Shape: [].
|
||||||
|
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
|
||||||
|
|
||||||
|
Returns: loss_vhc (torch.Tensor): The vision-history contrastive loss. Shape: [].
|
||||||
|
|
||||||
|
"""
|
||||||
|
if all_gather:
|
||||||
|
gather_args = self.get_gather_args()
|
||||||
|
vis_proj = allgather_wgrad(vis_proj, gather_args)
|
||||||
|
hist_proj = allgather_wgrad(hist_proj, gather_args)
|
||||||
|
if idx is not None:
|
||||||
|
idx = allgather_wgrad(idx, gather_args)
|
||||||
|
|
||||||
|
sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_v2h_targets = self.get_mask(sim_v2h, idx=idx, normalize=True)
|
||||||
|
sim_h2v_targets = sim_v2h_targets
|
||||||
|
|
||||||
|
loss_v2h = -torch.sum(F.log_softmax(sim_v2h, dim=1) * sim_v2h_targets, dim=1).mean()
|
||||||
|
loss_h2v = -torch.sum(F.log_softmax(sim_h2v, dim=1) * sim_h2v_targets, dim=1).mean()
|
||||||
|
|
||||||
|
loss_vhc = (loss_v2h + loss_h2v) / 2
|
||||||
|
return loss_vhc
|
||||||
|
|
||||||
|
def vhm_loss(
|
||||||
|
self,
|
||||||
|
grounding_expert,
|
||||||
|
vhm_head,
|
||||||
|
vis_embeds_orig,
|
||||||
|
hist_embeds_orig,
|
||||||
|
vis_proj,
|
||||||
|
hist_proj,
|
||||||
|
hist_atts,
|
||||||
|
idx,
|
||||||
|
generation=False,
|
||||||
|
temp=1.0,
|
||||||
|
):
|
||||||
|
vis_embeds = vis_embeds_orig.clone()
|
||||||
|
hist_embeds = hist_embeds_orig.clone()
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)
|
||||||
|
vis_atts = torch.ones(
|
||||||
|
vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
weights_v2h = F.softmax(sim_v2h + 1e-4, dim=1) # (N, N)
|
||||||
|
weights_h2v = F.softmax(sim_h2v + 1e-4, dim=1)
|
||||||
|
|
||||||
|
mask = self.get_mask(weights_v2h, idx=idx).bool()
|
||||||
|
weights_v2h.masked_fill_(mask, 0)
|
||||||
|
weights_h2v.masked_fill_(mask, 0)
|
||||||
|
weights_v2h = torch.nan_to_num_(weights_v2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
weights_h2v = torch.nan_to_num_(weights_h2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = grounding_expert(
|
||||||
|
encoder_embeds=hist_embeds,
|
||||||
|
attention_mask=hist_atts,
|
||||||
|
encoder_hidden_states=vis_embeds,
|
||||||
|
encoder_attention_mask=vis_atts,
|
||||||
|
return_dict=True,
|
||||||
|
# mode="fusion",
|
||||||
|
)
|
||||||
|
pos_feats = output.last_hidden_state
|
||||||
|
return pos_feats
|
||||||
|
|
||||||
|
else:
|
||||||
|
# select hard negatives within the batch
|
||||||
|
vis_neg_indices = torch.multinomial(weights_v2h, 1).squeeze()
|
||||||
|
hist_neg_indices = torch.multinomial(weights_h2v, 1).squeeze()
|
||||||
|
|
||||||
|
vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c]
|
||||||
|
hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d]
|
||||||
|
hist_atts_neg = hist_atts[hist_neg_indices]
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
|
||||||
|
hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
|
||||||
|
vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
|
||||||
|
hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)
|
||||||
|
|
||||||
|
output = grounding_expert(
|
||||||
|
inputs_embeds=hist_embeds_all,
|
||||||
|
attention_mask=hist_atts_all,
|
||||||
|
cross_embeds=vis_embeds_all,
|
||||||
|
cross_attention_mask=vis_atts_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
vhm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
|
||||||
|
|
||||||
|
vhm_logits = vhm_head(vhm_embeds) # [3*B, 2]
|
||||||
|
|
||||||
|
bs = vhm_logits.shape[0] // 3
|
||||||
|
vhm_labels = vhm_logits.new_ones(3 * bs, dtype=torch.long)
|
||||||
|
vhm_labels[bs:] = 0
|
||||||
|
loss_vhm = F.cross_entropy(vhm_logits, vhm_labels)
|
||||||
|
pos_feats = output.last_hidden_state[:bs]
|
||||||
|
|
||||||
|
return loss_vhm, pos_feats
|
||||||
|
|
||||||
|
|
||||||
|
class CHC_CHM_Loss(ContMatchLoss):
|
||||||
|
"""Contrastive and matching losses"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CHC_CHM_Loss, self).__init__()
|
||||||
|
|
||||||
|
def chc_loss(
|
||||||
|
self,
|
||||||
|
cap_proj: torch.Tensor,
|
||||||
|
hist_proj: torch.Tensor,
|
||||||
|
idx: torch.Tensor,
|
||||||
|
temp=1.0,
|
||||||
|
all_gather=True
|
||||||
|
):
|
||||||
|
|
||||||
|
"""forward to calculate the loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cap_proj (torch.Tensor): The caption representation. Shape: [B,T,C] or [B,C].
hist_proj (torch.Tensor): The dialog history representation. Shape: [B,C].
|
||||||
|
idx (torch.Tensor): The index for each example. Shape: [B,].
|
||||||
|
temp (torch.Tensor): The temperature. Shape: [].
|
||||||
|
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
|
||||||
|
|
||||||
|
Returns: loss_chc (torch.Tensor): The caption-history contrastive loss. Shape: [].
|
||||||
|
|
||||||
|
"""
|
||||||
|
if all_gather:
|
||||||
|
gather_args = self.get_gather_args()
|
||||||
|
cap_proj = allgather_wgrad(cap_proj, gather_args)
|
||||||
|
hist_proj = allgather_wgrad(hist_proj, gather_args)
|
||||||
|
if idx is not None:
|
||||||
|
idx = allgather_wgrad(idx, gather_args)
|
||||||
|
|
||||||
|
sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_c2h_targets = self.get_mask(sim_c2h, idx=idx, normalize=True)
|
||||||
|
sim_h2c_targets = sim_c2h_targets
|
||||||
|
|
||||||
|
loss_c2h = -torch.sum(F.log_softmax(sim_c2h, dim=1) * sim_c2h_targets, dim=1).mean()
|
||||||
|
loss_h2c = -torch.sum(F.log_softmax(sim_h2c, dim=1) * sim_h2c_targets, dim=1).mean()
|
||||||
|
|
||||||
|
loss_chc = (loss_c2h + loss_h2c) / 2
|
||||||
|
return loss_chc
|
||||||
|
|
||||||
|
def chm_loss(
|
||||||
|
self,
|
||||||
|
grounding_expert,
|
||||||
|
chm_head,
|
||||||
|
cap_embeds_orig,
|
||||||
|
hist_embeds_orig,
|
||||||
|
cap_proj,
|
||||||
|
hist_proj,
|
||||||
|
cap_atts,
|
||||||
|
hist_atts,
|
||||||
|
idx,
|
||||||
|
generation=False,
|
||||||
|
temp=1.0
|
||||||
|
):
|
||||||
|
cap_embeds = cap_embeds_orig.clone()
|
||||||
|
hist_embeds = hist_embeds_orig.clone()
|
||||||
|
with torch.no_grad():
|
||||||
|
sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)
|
||||||
|
|
||||||
|
weights_c2h = F.softmax(sim_c2h + 1e-4, dim=1) # (N, N)
|
||||||
|
weights_h2c = F.softmax(sim_h2c + 1e-4, dim=1)
|
||||||
|
|
||||||
|
mask = self.get_mask(weights_c2h, idx=idx).bool()
|
||||||
|
weights_c2h.masked_fill_(mask, 0)
|
||||||
|
weights_h2c.masked_fill_(mask, 0)
|
||||||
|
weights_c2h = torch.nan_to_num_(weights_c2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
weights_h2c = torch.nan_to_num_(weights_h2c, nan=1e-2, posinf=1e-2, neginf=1e-2)
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = grounding_expert(
|
||||||
|
encoder_embeds=hist_embeds,
|
||||||
|
attention_mask=hist_atts,
|
||||||
|
encoder_hidden_states=cap_embeds,
|
||||||
|
encoder_attention_mask=cap_atts,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
pos_feats = output.last_hidden_state
|
||||||
|
return pos_feats
|
||||||
|
else:
|
||||||
|
# select hard negatives within the batch
|
||||||
|
cap_neg_indices = torch.multinomial(weights_c2h, 1).squeeze()
|
||||||
|
hist_neg_indices = torch.multinomial(weights_h2c, 1).squeeze()
|
||||||
|
|
||||||
|
cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, c]
|
||||||
|
cap_atts_neg = cap_atts[cap_neg_indices]
|
||||||
|
hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d]
|
||||||
|
hist_atts_neg = hist_atts[hist_neg_indices]
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
cap_embeds_all = torch.cat([cap_embeds, cap_embeds_neg, cap_embeds], dim=0)
|
||||||
|
hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
|
||||||
|
cap_atts_all = torch.cat([cap_atts, cap_atts_neg, cap_atts], dim=0)
|
||||||
|
hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)
|
||||||
|
|
||||||
|
output = grounding_expert(
|
||||||
|
inputs_embeds=hist_embeds_all,
|
||||||
|
attention_mask=hist_atts_all,
|
||||||
|
cross_embeds=cap_embeds_all,
|
||||||
|
cross_attention_mask=cap_atts_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
chm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
|
||||||
|
|
||||||
|
chm_logits = chm_head(chm_embeds) # [3*B, 2]
|
||||||
|
|
||||||
|
bs = chm_logits.shape[0] // 3
|
||||||
|
chm_labels = chm_logits.new_ones(3 * bs, dtype=torch.long)
|
||||||
|
chm_labels[bs:] = 0
|
||||||
|
loss_chm = F.cross_entropy(chm_logits, chm_labels)
|
||||||
|
pos_feats = output.last_hidden_state[:bs]
|
||||||
|
return loss_chm, pos_feats
|
||||||
|
|
||||||
|
|
||||||
|
class MLMLoss(nn.Module):
|
||||||
|
"""masked language modeling loss."""
|
||||||
|
|
||||||
|
def __init__(self, masking_prob, tokenizer):
|
||||||
|
super(MLMLoss, self).__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.masking_prob = masking_prob
|
||||||
|
|
||||||
|
def mlm_loss(
|
||||||
|
self,
|
||||||
|
text_encoder,
|
||||||
|
text,
|
||||||
|
text_embeds,
|
||||||
|
vision_embeds,
|
||||||
|
vision_atts,
|
||||||
|
):
|
||||||
|
input_ids = text.input_ids.clone()
|
||||||
|
labels = input_ids.clone()
|
||||||
|
probability_matrix = torch.full(labels.shape, self.masking_prob)
|
||||||
|
input_ids, labels = self.mask(
|
||||||
|
input_ids,
|
||||||
|
text_encoder.config.vocab_size,
|
||||||
|
input_ids.device,
|
||||||
|
targets=labels,
|
||||||
|
probability_matrix=probability_matrix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# intermediate_mlm_output = text_encoder.bert(
|
||||||
|
# input_ids,
|
||||||
|
# attention_mask=text.attention_mask,
|
||||||
|
# encoder_hidden_states=vision_embeds,
|
||||||
|
# encoder_attention_mask=vision_atts,
|
||||||
|
# return_dict=True,
|
||||||
|
# # mode="text",
|
||||||
|
# )
|
||||||
|
|
||||||
|
# text_embeds = intermediate_mlm_output.last_hidden_state
|
||||||
|
|
||||||
|
mlm_output = text_encoder(
|
||||||
|
encoder_embeds=text_embeds,
|
||||||
|
attention_mask=text.attention_mask,
|
||||||
|
encoder_hidden_states=vision_embeds,
|
||||||
|
encoder_attention_mask=vision_atts,
|
||||||
|
return_dict=True,
|
||||||
|
labels=labels,
|
||||||
|
soft_labels=None,
|
||||||
|
# mode="fusion",
|
||||||
|
)
|
||||||
|
return mlm_output.loss
|
||||||
|
|
||||||
|
def mask(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
vocab_size,
|
||||||
|
device,
|
||||||
|
targets=None,
|
||||||
|
masked_indices=None,
|
||||||
|
probability_matrix=None,
|
||||||
|
):
|
||||||
|
if masked_indices is None:
|
||||||
|
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||||
|
|
||||||
|
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
|
||||||
|
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
# We only compute loss on masked tokens
|
||||||
|
targets[~masked_indices] = -100
|
||||||
|
|
||||||
|
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
|
indices_replaced = (
|
||||||
|
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
|
||||||
|
)
|
||||||
|
input_ids[indices_replaced] = self.tokenizer.mask_token_id
|
||||||
|
|
||||||
|
# 10% of the time, we replace masked input tokens with random word
|
||||||
|
indices_random = (
|
||||||
|
torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
|
||||||
|
& masked_indices
|
||||||
|
& ~indices_replaced
|
||||||
|
)
|
||||||
|
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
|
||||||
|
input_ids[indices_random] = random_words[indices_random]
|
||||||
|
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
return input_ids, targets
|
||||||
|
else:
|
||||||
|
return input_ids
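

# Illustrative sketch of the corruption schedule implemented in `mask` above
# (hypothetical tiny tokenizer stub with ids 0/1/2 reserved for pad/cls/mask):
# of the positions selected with `masking_prob`, 80% become the mask id, ~10%
# a random id, and ~10% keep their original id; every unselected position gets
# label -100 so the LM loss ignores it.
def _mlm_mask_demo():
    from types import SimpleNamespace
    toy_tokenizer = SimpleNamespace(pad_token_id=0, cls_token_id=1, mask_token_id=2)
    mlm = MLMLoss(masking_prob=0.15, tokenizer=toy_tokenizer)
    input_ids = torch.randint(3, 100, (2, 16))
    labels = input_ids.clone()
    probability_matrix = torch.full(labels.shape, mlm.masking_prob)
    input_ids, labels = mlm.mask(
        input_ids, vocab_size=100, device=input_ids.device,
        targets=labels, probability_matrix=probability_matrix,
    )
    assert (labels[labels != -100] >= 3).all()  # only real tokens are supervised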
|
0
models/modules/__init__.py
Normal file
286
models/modules/temporal_modelling.py
Normal file
@@ -0,0 +1,286 @@
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from timm.models.layers.drop import DropPath
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import LayerNorm, Linear, MultiheadAttention
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class STAdapter(nn.Module):
|
||||||
|
"""ST Adapter"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
input_dim=768,
|
||||||
|
hidden_dim=384,
|
||||||
|
img_size=224,
|
||||||
|
patch_size=16,
|
||||||
|
drop_prob=0.1,
|
||||||
|
):
|
||||||
|
super(STAdapter, self).__init__()
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
|
||||||
|
self.h = self.w = img_size // patch_size
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(input_dim, hidden_dim)
|
||||||
|
self.linear2 = nn.Linear(hidden_dim, input_dim)
|
||||||
|
self.act = nn.ReLU()
|
||||||
|
self.conv = nn.Conv3d(
|
||||||
|
hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim
|
||||||
|
)
|
||||||
|
self.droppath = DropPath(drop_prob=drop_prob)
|
||||||
|
|
||||||
|
self.scale = nn.parameter.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
|
||||||
|
|
||||||
|
Returns: features after adapter. The same shape as input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if x.shape[1] == 1: # for single frame, return itself.
|
||||||
|
return x
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = self.linear1(x)
|
||||||
|
cls = x[:, :, :1, :]
|
||||||
|
tokens = x[:, :, 1:, :]
|
||||||
|
tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous()
|
||||||
|
tokens = self.conv(tokens)
|
||||||
|
tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c")
|
||||||
|
x = torch.cat([cls, tokens], dim=2) # [b, t, 1+h*w, c]
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
|
||||||
|
return shortcut + self.scale * self.droppath(x)
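

# Illustrative shape check (random inputs): the adapter is shape preserving on
# [bs, nframes, 1 + h*w, c]; with img_size=224 and patch_size=16, h = w = 14.
def _st_adapter_shape_demo():
    adapter = STAdapter(input_dim=768, hidden_dim=384, img_size=224, patch_size=16)
    x = torch.randn(2, 4, 1 + 14 * 14, 768)  # [bs, T, 1 + h*w, c]
    assert adapter(x).shape == x.shape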
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialAttention(nn.Module):
|
||||||
|
"""Perfrom spatial self-attention"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim=768, droppath_rate=0.1):
|
||||||
|
super(SpatialAttention, self).__init__()
|
||||||
|
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
|
||||||
|
self.norm = LayerNorm(input_dim, eps=1e-12)
|
||||||
|
self.linear = Linear(input_dim, input_dim)
|
||||||
|
self.droppath = DropPath(droppath_rate)
|
||||||
|
# self.scale = nn.parameter.Parameter(torch.zeros([]))
|
||||||
|
self.scale = 1.0
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if x.shape[1] == 1:
|
||||||
|
x = self.norm(x)
|
||||||
|
x = einops.rearrange(x, "b t l c -> b (t l) c")
|
||||||
|
return x  # single-frame (image) input: return early
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = einops.rearrange(x, 'b t l c -> (b t) l c')
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.attn(x, x, x)[0]
|
||||||
|
x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0])
|
||||||
|
x = shortcut + self.scale * self.droppath(x)
|
||||||
|
x = einops.rearrange(x, "b t l c -> b (t l) c")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalAttention(nn.Module):
|
||||||
|
|
||||||
|
"""perform temporal self-attention"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim=768, droppath_rate=0.1):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Kwargs:
|
||||||
|
input_dim (int): The input feature dimension.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
super(TemporalAttention, self).__init__()
|
||||||
|
|
||||||
|
self._input_dim = input_dim
|
||||||
|
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
|
||||||
|
self.norm = LayerNorm(input_dim, eps=1e-12)
|
||||||
|
self.linear = Linear(input_dim, input_dim)
|
||||||
|
self.droppath = DropPath(droppath_rate)
|
||||||
|
# self.scale = nn.parameter.Parameter(torch.zeros([]))
|
||||||
|
self.scale = 1.0
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
|
||||||
|
|
||||||
|
Returns: features after adapter. The same shape as input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if x.shape[1] == 1: # for single frame, return itself.
|
||||||
|
x = self.norm(x)
|
||||||
|
x = einops.rearrange(x, "b t l c -> b (t l) c")
|
||||||
|
return x
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = einops.rearrange(x, "b t l c -> (b l) t c")
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.attn(x, x, x)[0]
|
||||||
|
x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0])
|
||||||
|
x = shortcut + self.scale * self.droppath(x)
|
||||||
|
x = einops.rearrange(x, "b t l c -> b (t l) c")
|
||||||
|
return x
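

# Illustrative contrast of the two attention blocks (random inputs): spatial
# attention mixes the l tokens within each frame ("(b t) l c"), temporal
# attention mixes the t frames of each token position ("(b l) t c"); both
# return the flattened [b, t * l, c] sequence.
def _spatio_temporal_attention_demo():
    x = torch.randn(2, 4, 197, 768)  # [bs, nframes, l, c]
    assert SpatialAttention(input_dim=768)(x).shape == (2, 4 * 197, 768)
    assert TemporalAttention(input_dim=768)(x).shape == (2, 4 * 197, 768)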
|
||||||
|
|
||||||
|
|
||||||
|
class WindowTemporalAttention(nn.Module):
|
||||||
|
|
||||||
|
"""perform windowed temporal self-attention"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Kwargs:
|
||||||
|
input_dim (int): The input feature dimension.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._input_dim = input_dim
|
||||||
|
self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64)
|
||||||
|
self.norm = LayerNorm(input_dim, eps=1e-12)
|
||||||
|
self.droppath = DropPath(droppath_rate)
|
||||||
|
self.scale = nn.parameter.Parameter(torch.zeros([]))
|
||||||
|
self.wh, self.ww = window_size
|
||||||
|
# logger.info(f"WindowTemporalAttention: window_size: {window_size}")
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
|
||||||
|
|
||||||
|
Returns: features after adapter. The same shape as input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if x.shape[1] == 1: # for single frame, return itself.
|
||||||
|
return x
|
||||||
|
shortcut = x
|
||||||
|
|
||||||
|
h = w = int(math.sqrt(x.shape[2] - 1))
|
||||||
|
cls_token = x[:, :, :1, :]
|
||||||
|
x = einops.rearrange(
|
||||||
|
x[:, :, 1:, :],
|
||||||
|
"b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c",
|
||||||
|
nh=h // self.wh,
|
||||||
|
wh=self.wh,
|
||||||
|
nw=w // self.ww,
|
||||||
|
ww=self.ww,
|
||||||
|
)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.temporal_attn(x, x, x)[0]
|
||||||
|
x = einops.rearrange(
|
||||||
|
x,
|
||||||
|
"(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c",
|
||||||
|
wh=self.wh,
|
||||||
|
ww=self.ww,
|
||||||
|
nh=h // self.wh,
|
||||||
|
nw=w // self.ww,
|
||||||
|
)
|
||||||
|
# add back cls token.
|
||||||
|
x = torch.concat([cls_token, x], dim=2)
|
||||||
|
return shortcut + self.scale * self.droppath(x)
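

# Illustrative shape check (random input; assumes h and w are divisible by the
# window size): with 14x14 patch tokens and a (2, 2) window, attention runs over
# t * wh * ww positions inside each of the 7x7 windows.
def _window_temporal_attention_demo():
    wta = WindowTemporalAttention(input_dim=768, window_size=(2, 2))
    x = torch.randn(2, 4, 1 + 14 * 14, 768)  # [bs, nframes, 1 + h*w, c]
    assert wta(x).shape == x.shape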
|
||||||
|
|
||||||
|
|
||||||
|
class X_CLIP(nn.Module):
|
||||||
|
|
||||||
|
"""perform windowed temporal self-attention"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Kwargs:
|
||||||
|
input_dim (int): The input feature dimension.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
d_model = input_dim
|
||||||
|
|
||||||
|
self.message_fc = nn.Linear(d_model, d_model)
|
||||||
|
self.message_ln = LayerNorm(d_model, eps=1e-12)
|
||||||
|
self.message_attn = nn.MultiheadAttention(d_model, d_model // 64)
|
||||||
|
self.num_prompts = num_prompts
|
||||||
|
|
||||||
|
self.droppath = DropPath(droppath_rate)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
|
||||||
|
|
||||||
|
Returns: features after adapter. The same shape as input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if x.shape[1] == 1: # for single frame, return itself.
|
||||||
|
return x
|
||||||
|
msg_token = self.message_ln(self.message_fc(x[:, :, 0, :])) # [b, t, c]
|
||||||
|
msg_token = rearrange(msg_token, "b t c -> t b c")
|
||||||
|
msg_token = msg_token + self.droppath(
|
||||||
|
self.message_attn(msg_token, msg_token, msg_token)[0]
|
||||||
|
)
|
||||||
|
msg_token = rearrange(msg_token, "t b c -> b t c")
|
||||||
|
# replace the last prompt token with msg_token.
|
||||||
|
x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2) # [b, t, l+1, c]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalS4(nn.Module):
|
||||||
|
|
||||||
|
"""perform temporal self-attention"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim=768, droppath_rate=0.1):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Kwargs:
|
||||||
|
input_dim (int): The input feature dimension.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
from .s4 import S4
|
||||||
|
|
||||||
|
self._input_dim = input_dim
|
||||||
|
self.norm = LayerNorm(input_dim, eps=1e-12)
|
||||||
|
self.droppath = DropPath(droppath_rate)
|
||||||
|
self.scale = nn.parameter.Parameter(torch.zeros([]))
|
||||||
|
self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
|
||||||
|
|
||||||
|
Returns: features after adapter. The same shape as input.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if x.shape[1] == 1: # for single frame, return itself.
|
||||||
|
return x
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = self.norm(x)
|
||||||
|
x = einops.rearrange(x, "b t l c -> b c (t l)")
|
||||||
|
x, _ = self.s4(x)
|
||||||
|
x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1])
|
||||||
|
return shortcut + self.scale * self.droppath(x)
|
358
models/setup.py
Normal file
@@ -0,0 +1,358 @@
|
||||||
|
import copy
|
||||||
|
import os.path as osp
|
||||||
|
import glog as logger
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import ConcatDataset
|
||||||
|
from models.backbones.beit.builder import interpolate_pos_embed_beit
|
||||||
|
from models.backbones.bert.tokenization_bert import BertTokenizer
|
||||||
|
from transformers import T5Tokenizer, BartTokenizer, LlamaTokenizer
|
||||||
|
from utils.optimizer import create_optimizer
|
||||||
|
from utils.scheduler import create_scheduler
|
||||||
|
from datasets.dataloader import load_dataloaders
|
||||||
|
from datasets.pretraining import load_datasets as load_datasets_stage_1
|
||||||
|
from datasets.visdial_dataset import load_visdial_dataset
|
||||||
|
from datasets.champagne_dataset import load_champagne_dataset
|
||||||
|
from datasets.nextqa_dataset import load_nextqa_dataset
|
||||||
|
from datasets.avsd_dataset import load_avsd_dataset
|
||||||
|
# from datasets.avsd_dataset_like_mixer import load_avsd_dataset
|
||||||
|
|
||||||
|
from processors.blip_processors import Blip2ImageTrainProcessor
|
||||||
|
from processors.blip_processors import BlipCaptionProcessor, BlipDialogProcessor
|
||||||
|
|
||||||
|
from utils.init import set_training_steps
|
||||||
|
# from models.v2dial import V2Dial, V2DialBase
|
||||||
|
from models.v2dial import V2DialBase, V2Dial, V2DialNoMoes
|
||||||
|
|
||||||
|
# from datasets.avsd_dataset import get_dataset, AVSDDataSet
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model(
|
||||||
|
config, has_decoder=False, pretrain=False, find_unused_parameters=True
|
||||||
|
):
|
||||||
|
logger.info("Creating model")
|
||||||
|
|
||||||
|
if config['stage'] == 'stage_1':
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
|
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
# model = V2DialBase(config=config, expert_tokenizer=tokenizer)
|
||||||
|
model = V2DialBase(config)
|
||||||
|
model = model.to(torch.device('cuda'))
|
||||||
|
model_without_ddp = model
|
||||||
|
optimizer = create_optimizer(config, model)
|
||||||
|
scheduler = create_scheduler(config, optimizer)
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
|
||||||
|
|
||||||
|
if config['distributed']:
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
|
model,
|
||||||
|
device_ids=[config['gpu']],
|
||||||
|
find_unused_parameters=find_unused_parameters, # `False` for image-only task
|
||||||
|
)
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
global_step = 0
|
||||||
|
webvid_step = 0
|
||||||
|
cc3m_step = 0
|
||||||
|
|
||||||
|
if osp.isfile(config['pretrained_path']):
|
||||||
|
logger.info(f"Loading checkpoint from {config['pretrained_path']}")
|
||||||
|
checkpoint = torch.load(config['pretrained_path'], map_location="cpu")
|
||||||
|
state_dict = checkpoint["model"]
|
||||||
|
|
||||||
|
if config.resume:
|
||||||
|
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
|
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||||
|
scaler.load_state_dict(checkpoint["scaler"])
|
||||||
|
start_epoch = checkpoint["epoch"] + 1
|
||||||
|
global_step = checkpoint["global_step"]
|
||||||
|
elif not pretrain: # downstream init from pretrained ckpt
|
||||||
|
|
||||||
|
# interpolate positional embeddings.
|
||||||
|
state_dict = interpolate_pos_embed_beit(state_dict, model_without_ddp)
|
||||||
|
|
||||||
|
|
||||||
|
#TODO Might need to update to match the MoEs
|
||||||
|
if not config.evaluate: # fine-tuning from pretrained weights
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
if "bert" in key:
|
||||||
|
encoder_key = key.replace("bert.", "")
|
||||||
|
state_dict[encoder_key] = state_dict[key]
|
||||||
|
if not has_decoder:
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
# init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
|
||||||
|
# only for generation tasks like VQA
|
||||||
|
if has_decoder and "text_encoder" in key:
|
||||||
|
if "layer" in key:
|
||||||
|
encoder_keys = key.split(".")
|
||||||
|
layer_num = int(encoder_keys[4])
|
||||||
|
if layer_num < config.model.text_encoder.fusion_layer:
|
||||||
|
del state_dict[key]
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
decoder_layer_num = layer_num - 9
|
||||||
|
encoder_keys[4] = str(decoder_layer_num)
|
||||||
|
encoder_key = ".".join(encoder_keys)
|
||||||
|
else:
|
||||||
|
encoder_key = key
|
||||||
|
decoder_key = encoder_key.replace("text_encoder", "text_decoder")
|
||||||
|
state_dict[decoder_key] = state_dict[key]
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
|
||||||
|
logger.info(msg)
|
||||||
|
logger.info(f"Loaded checkpoint from {config.pretrained_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("No pretrained checkpoint provided, training from scratch")
|
||||||
|
|
||||||
|
return (
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
global_step,
|
||||||
|
webvid_step,
|
||||||
|
cc3m_step,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# config = copy.deepcopy(config)
|
||||||
|
# if config['use_original_feats']:
|
||||||
|
# model = AVSDBart(config)
|
||||||
|
# else:
|
||||||
|
# # model = V2Dial(config, tokenizer_experts, tokenizer_enc_dec)
|
||||||
|
# if config.use_moes:
|
||||||
|
model = V2Dial(config)
|
||||||
|
# else:
|
||||||
|
# model = V2DialNoMoes(config)
|
||||||
|
|
||||||
|
model = model.to(torch.device('cuda'))
|
||||||
|
model_without_ddp = model
|
||||||
|
|
||||||
|
optimizer = None
|
||||||
|
scheduler = None
|
||||||
|
scaler = None
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
global_step = 0
|
||||||
|
if config['stage'] == 'stage_3':
|
||||||
|
visdial_step = 0
|
||||||
|
avsd_step = 0
|
||||||
|
nextqa_step = 0
|
||||||
|
|
||||||
|
ckpt_path = config.pretrained_path_resume if config.resume else config.pretrained_path_prev_stage
|
||||||
|
if config.generating:
|
||||||
|
ckpt_path = config.best_ckpt_path
|
||||||
|
|
||||||
|
if osp.isfile(ckpt_path):
|
||||||
|
logger.info(f"Loading checkpoint from {ckpt_path}")
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
state_dict = checkpoint["model"]
|
||||||
|
|
||||||
|
if config.resume:
|
||||||
|
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
|
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||||
|
scaler.load_state_dict(checkpoint["scaler"])
|
||||||
|
start_epoch = checkpoint["epoch"] + 1
|
||||||
|
global_step = checkpoint["global_step"]
|
||||||
|
if config['stage'] == 'stage_3':
|
||||||
|
visdial_step = checkpoint['visdial_step']
|
||||||
|
avsd_step = checkpoint['avsd_step']
|
||||||
|
nextqa_step = checkpoint['nextqa_step']
|
||||||
|
|
||||||
|
|
||||||
|
if config['stage'] in ['stage_2', 'stage_3'] and config.use_moes:
|
||||||
|
# Init. the history expert weights with the caption expert weights
|
||||||
|
p_names = [
|
||||||
|
'moe_layers.{}.norm_hist.weight',
|
||||||
|
'moe_layers.{}.mlp_hist.fc1.weight',
|
||||||
|
'moe_layers.{}.mlp_hist.fc1.bias',
|
||||||
|
'moe_layers.{}.mlp_hist.fc2.weight',
|
||||||
|
'moe_layers.{}.mlp_hist.fc2.bias',
|
||||||
|
]
|
||||||
|
|
||||||
|
for moe_layer_idx in range(config.num_moe_modality_layers):
|
||||||
|
for p_name in p_names:
|
||||||
|
p_hist_name = p_name.format(moe_layer_idx)
|
||||||
|
if p_hist_name not in state_dict:
|
||||||
|
p_cap_name = p_hist_name.replace('hist', 'cap')
|
||||||
|
state_dict[p_hist_name] = state_dict[p_cap_name].clone()
|
||||||
|
|
||||||
|
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
logger.info(f"Loaded checkpoint from {ckpt_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("No pretrained checkpoint provided, training from scratch")
|
||||||
|
|
||||||
|
if config['training']:
|
||||||
|
optimizer = create_optimizer(config, model_without_ddp)
|
||||||
|
scheduler = create_scheduler(config, optimizer)
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
|
||||||
|
|
||||||
|
elif config['generating']:
|
||||||
|
model.llm.set_input_embeddings(model.text_embedding)
|
||||||
|
|
||||||
|
if config['distributed']:
|
||||||
|
|
||||||
|
static_graph = config.stage != 'stage_1'
|
||||||
|
if len(config.media_train) > 0:
|
||||||
|
static_graph = False
|
||||||
|
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
|
model_without_ddp,
|
||||||
|
device_ids=[config['gpu']],
|
||||||
|
find_unused_parameters=find_unused_parameters, # `False` for image-only task
|
||||||
|
static_graph=static_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
if config['stage'] == 'stage_3':
|
||||||
|
return (
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
global_step,
|
||||||
|
visdial_step,
|
||||||
|
avsd_step,
|
||||||
|
nextqa_step,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
model,
|
||||||
|
model_without_ddp,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
scaler,
|
||||||
|
start_epoch,
|
||||||
|
global_step,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_data(config):
|
||||||
|
logger.info("[INFO] Creating datasets")
|
||||||
|
|
||||||
|
# define the processors
|
||||||
|
vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
|
||||||
|
|
||||||
|
if config['stage'] == 'stage_1':
|
||||||
|
text_processor = BlipCaptionProcessor(max_words=config.max_cap_len)
|
||||||
|
|
||||||
|
if config['debugging']:
|
||||||
|
train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')
|
||||||
|
else:
|
||||||
|
        train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'train')

        val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')

        # cc3m_dataset = ConcatDataset([train_datasets['cc3m'], val_datasets['cc3m']])
        # webvid_dataset = ConcatDataset([train_datasets['webvid'], val_datasets['webvid']])
        # train_datasets = [cc3m_dataset, webvid_dataset]
        train_datasets = list(train_datasets.values())
        val_datasets = list(val_datasets.values())

        batch_sizes = [config['batch_size_cc3m'], config['batch_size_webvid']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

        # val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'test')
        # val_dataloader = load_dataloaders(config, val_datasets, 'test', output_dict=True)

    if config['stage'] == 'stage_2':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50
        train_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'train')]
        val_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'val')]
        batch_sizes = [config['batch_size_champagne']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

    if config['stage'] == 'stage_3':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50
        train_datasets = []
        val_datasets = []
        for medium in config['media_train']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset

            train_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'train'))

        for medium in config['media_val']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset

            val_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'val'))

        batch_sizes = [d.batch_size for d in train_datasets]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

    return train_dataloaders, val_dataloaders


def setup_data_test(config):
    vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
    text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50

    if config.media_test == 'visdial':
        load_dataset_fn = load_visdial_dataset
    elif config.media_test == 'avsd':
        load_dataset_fn = load_avsd_dataset
    elif config.media_test == 'nextqa':
        load_dataset_fn = load_nextqa_dataset
    test_dataset = load_dataset_fn(config, vis_processor, text_processor, 'test')

    test_dataloader = DataLoader(
        test_dataset, shuffle=False, batch_size=test_dataset.batch_size)

    return test_dataloader


# def setup_data_test(config, args):
#     tokenizer_experts = BertTokenizer.from_pretrained('bert-base-uncased')
#     tokenizer_enc_dec = None
#     if config.enc_dec_family == 'flan_t5':
#         tokenizer_enc_dec = T5Tokenizer.from_pretrained(config.enc_dec_name)
#     elif config.enc_dec_family == 'bart':
#         tokenizer_enc_dec = BartTokenizer.from_pretrained(config.enc_dec_name)
#     if config['tie_embeddings']:
#         tokenizer_experts = tokenizer_enc_dec

#     if config['medium'] == 'avsd':
#         test_dataset = AVSDDataSet(config, 'avsd', tokenizer_experts, tokenizer_enc_dec, 'test')
#     test_dataloader = DataLoader(
#         test_dataset, shuffle=False, batch_size=test_dataset.batch_size, collate_fn=test_dataset.collate_fn)
#     return test_dataloader

266
models/utils.py
Normal file

@@ -0,0 +1,266 @@
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import interpolate
from typing import List

logger = logging.getLogger(__name__)


class MLM:
    def __init__(
        self,
        mask_token: int,
        padding_token: int,
        no_mask_tokens: List[int],
        n_tokens: int,
        masking_prob: float = 0.15,
        randomize_prob: float = 0.1,
        no_change_prob: float = 0.1
    ):
        self.mask_token = mask_token
        self.padding_token = padding_token
        self.no_mask_tokens = list(set(no_mask_tokens + [padding_token, mask_token]))
        self.n_tokens = n_tokens
        self.masking_prob = masking_prob
        self.randomize_prob = randomize_prob
        self.no_change_prob = no_change_prob

    def __call__(self, x: torch.Tensor):
        full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
        for tok in self.no_mask_tokens:
            full_mask &= x != tok  # exclude special/padding tokens from masking --> 0

        unchanged_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
        random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
        random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
        random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
        mask = full_mask & ~random_token_mask & ~unchanged_mask

        y = x.clone().detach()
        x.masked_fill_(mask, self.mask_token)
        x[random_token_idx] = random_tokens
        y.masked_fill_(~full_mask, self.padding_token)

        return x, y

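
# Usage sketch for the MLM helper above. The special ids here (mask=103, pad=0,
# cls=101, sep=102, vocab size 30522) are only assumptions following bert-base-uncased
# conventions; substitute the ids of the tokenizer this project actually uses.
# Note that __call__ mutates its argument in place, so pass a clone.
#
#     mlm = MLM(mask_token=103, padding_token=0, no_mask_tokens=[101, 102], n_tokens=30522)
#     input_ids = torch.randint(1000, 2000, (2, 16))      # toy batch of token ids
#     masked_ids, labels = mlm(input_ids.clone())
#     # labels hold the original ids at masked positions and padding_token elsewhere.
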
def _init_transformer_weights(module, initializer_range=0.02):
    """Initialize the weights. Copied from transformers ViT/BERT model init."""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # Slightly different from the TF version, which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """
    Add/remove extra temporal embeddings as needed.
    https://arxiv.org/abs/2104.00650 shows that adding zero paddings works.

    temp_embed_old: (1, num_frames_old, 1, d)
    temp_embed_new: (1, num_frames_new, 1, d)
    add_zero: bool; if True, zero-pad the new positions, else interpolate the trained embeddings.
    """
    # TODO zero pad
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
    if num_frms_new > num_frms_old:
        if add_zero:
            temp_embed_new[
                :, :num_frms_old
            ] = temp_embed_old  # untrained embeddings are zeros.
        else:
            temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
    elif num_frms_new < num_frms_old:
        temp_embed_new = temp_embed_old[:, :num_frms_new]
    else:  # =
        temp_embed_new = temp_embed_old
    return temp_embed_new


def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """
    temp_embed_old: (1, num_frames_old, 1, d)
    Returns:
        temp_embed_new: (1, num_frames_new, 1, d)
    """
    temp_embed_old = temp_embed_old.squeeze(2).permute(
        0, 2, 1
    )  # (1, d, num_frames_old)
    temp_embed_new = F.interpolate(
        temp_embed_old, num_frames_new, mode="linear"
    )  # (1, d, num_frames_new)
    temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(
        2
    )  # (1, num_frames_new, 1, d)
    return temp_embed_new

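
# Usage sketch: growing a trained temporal embedding from 4 to 8 frames, following the
# shapes documented above. The sizes (4 frames, 8 frames, hidden dim 768) are illustrative
# assumptions, not values taken from a config.
#
#     temp_old = torch.randn(1, 4, 1, 768)                 # trained for 4 frames
#     temp_new = torch.zeros(1, 8, 1, 768)                  # freshly initialised for 8 frames
#     padded = load_temp_embed_with_mismatch(temp_old, temp_new, add_zero=True)
#     interpolated = interpolate_temporal_pos_embed(temp_old, num_frames_new=8)
#     assert padded.shape == interpolated.shape == (1, 8, 1, 768)
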
def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
    """
    Args:
        pos_embed_old: (1, L_old, d), pre-trained
        pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
        num_patches_new: number of patches at the new image resolution
    """
    # interpolate position embedding
    embedding_size = pos_embed_old.shape[-1]
    num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches_new ** 0.5)

    if orig_size != new_size:
        # class_token and dist_token are kept unchanged
        # the extra tokens seem to always sit at the beginning of the position embedding
        extra_tokens = pos_embed_old[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_old[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size
        ).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
        return interpolated_pos_embed
    else:
        return pos_embed_old

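
# Usage sketch: resizing a ViT-B/16 position embedding from 224x224 input (14x14 = 196
# patches + 1 extra [CLS] token) to 384x384 input (24x24 = 576 patches). The concrete
# sizes are assumptions for illustration.
#
#     pos_old = torch.randn(1, 197, 768)                   # from the pretrained checkpoint
#     pos_new = torch.zeros(1, 577, 768)                    # shape expected at the new resolution
#     pos_resized = interpolate_pos_embed(pos_old, pos_new, num_patches_new=576)
#     assert pos_resized.shape == (1, 577, 768)
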
def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
    """
    Args:
        state_dict_old: loaded state dict
        state_dict_new: state dict for model with new image size
        patch_shape_new: new model patch_shape
    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
    """
    all_keys = list(state_dict_old.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            state_dict_old.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict_old[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = state_dict_new[key].size()
            dst_patch_shape = patch_shape_new
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
                dst_patch_shape[1] * 2 - 1
            )
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                # logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
                #     key, src_size, src_size, dst_size, dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                # logger.info("Original positions = %s" % str(x))
                # logger.info("Target positions = %s" % str(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    # NOTE: scipy.interpolate.interp2d is deprecated and removed in recent SciPy
                    # releases; newer environments would need e.g. RectBivariateSpline here.
                    f = interpolate.interp2d(x, y, z, kind="cubic")
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy))
                        .contiguous()
                        .view(-1, 1)
                        .to(rel_pos_bias.device)
                    )

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict_old[key] = new_rel_pos_bias
    return state_dict_old


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))


def mask_logits(target, mask):
    return target * mask + (1 - mask) * (-1e10)

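
# Usage sketch: tile repeats every entry n_tile times consecutively along dim (whereas
# torch.Tensor.repeat repeats the whole sequence), and mask_logits drives padded positions
# towards -1e10 before a softmax. The tensors are illustrative.
#
#     tile(torch.tensor([1, 2, 3]), dim=0, n_tile=2)        # -> tensor([1, 1, 2, 2, 3, 3])
#
#     logits = torch.tensor([2.0, 0.5, -1.0])
#     mask = torch.tensor([1.0, 1.0, 0.0])                   # 0 marks a padded position
#     mask_logits(logits, mask)                              # -> tensor([2.0, 0.5, -1e10])
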
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None,
        )


allgather_wgrad = AllGather.apply
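
# Usage sketch: allgather_wgrad is typically used to gather features from every rank while
# keeping gradients for the local shard, e.g. for a contrastive objective. It assumes an
# initialised torch.distributed process group and an `args` object exposing `world_size`
# and `rank`, as required by AllGather.forward above.
#
#     feats = model(batch)                                   # (B, d) on this rank
#     all_feats = allgather_wgrad(feats, args)                # (B * world_size, d)
#     sim = feats @ all_feats.t()                             # local-vs-global similarity matrix
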
2213
models/v2dial.py
Normal file
File diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.