commit a82bbc593eb430a687b43db6506e5d33314e2490 Author: Andreas Bulling Date: Tue Jun 24 08:38:09 2025 +0200 initial commit diff --git a/data/CGRUM.mp4 b/data/CGRUM.mp4 new file mode 100644 index 0000000..881e52a Binary files /dev/null and b/data/CGRUM.mp4 differ diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datasets/avsd_dataset.py b/datasets/avsd_dataset.py new file mode 100644 index 0000000..20b9061 --- /dev/null +++ b/datasets/avsd_dataset.py @@ -0,0 +1,205 @@ +# coding: utf-8 +# author: noctli +import json +import os +import pickle +import logging +import numpy as np +from tqdm import tqdm +import torch +import torch.utils.data +from PIL import Image +from torch.utils.data import Dataset +from itertools import chain +from torchvision import transforms +from .utils import type_transform_helper +from itertools import chain +from .video_utils import read_frames_decord + + +def tokenize(text, tokenizer, return_tensor=False): + tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + if return_tensor: + return torch.tensor(tokenized_text).long() + return tokenized_text + + +def get_dataset(config, split): + if split != 'test': + dialog_pth = config[f'anno_avsd_{split}'] + else: + dialog_pth = config['anno_avsd_test_dstc_{}'.format(config['dstc'])] + n_history = config['num_hist_turns_avsd'] + undisclosed_only = split == 'test' + dialog_data = json.load(open(dialog_pth, 'r')) + dialog_list = [] + vid_set = set() + pbar = tqdm(dialog_data['dialogs']) + pbar.set_description('[INFO] Loading AVSD - {}'.format(split)) + for dialog in pbar: + # if config['dstc'] != 10: + caption = dialog['caption'] + summary = dialog['summary'] + # else: + # caption = 'no' + # summary = 'no' + + questions = [d['question'] for d in dialog['dialog']] + answers = [d['answer'] for d in dialog['dialog']] + vid = dialog["image_id"] + vid_set.add(vid) + if undisclosed_only: + it = range(len(questions) - 1, len(questions)) + else: + it = range(len(questions)) + qalist=[] + history = [] + if undisclosed_only: + for n in range(len(questions)-1): + qalist.append(questions[n]) + qalist.append(answers[n]) + history=qalist[max(-len(qalist),-n_history*2):] + for n in it: + if undisclosed_only: + assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__' + question = questions[n] + answer = answers[n] + history.append(question) + if n_history == 0: + item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'summary': summary} + else: + item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'summary': summary} + + dialog_list.append(item) + qalist.append(question) + qalist.append(answer) + history=qalist[max(-len(qalist),-n_history*2):] + + return dialog_list + + +def build_input_from_segments(caption, history, reply, tokenizer, drop_caption=False): + """ Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """ + + bos, eos = tokenizer.convert_tokens_to_ids(['', '']) + sep = eos + + instance = {} + instance["lm_labels"] = reply + [eos] + caption = list(chain(*caption)) + + if not drop_caption: + # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])] + + # NOTE It is important not to include the reply in the input of the encoder -- > the decoder will just + # learn to copy it --> low train/val loss but no learning is happening + sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]] + else: + sequence = [[bos]] + [[sep] + s for s in history] + [[eos]] + + instance["input_ids"] = list(chain(*sequence)) + return instance + + +class AVSDDataSet(Dataset): + def __init__(self, config, medium, vis_processor, text_processor, split + # tokenizer, features=None, drop_rate=0.0, train=True + ): + self.config = config + self.medium = medium + self.vis_processor = vis_processor + self.text_processor = text_processor + self.split = split + self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)] + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + self.dialogs = get_dataset(config, split) + + if split == 'test': + self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']] + + num_samples = config['num_samples_{}'.format(self.medium)] + if num_samples > 0: + self.dialogs = self.dialogs[:num_samples] + + def __len__(self): + return len(self.dialogs) + + def load_vid(self, vid_id): + vid_dir_path = os.path.join(self.root_vis, vid_id + '.mp4') + + frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames) + frames = [self.vis_processor(f).unsqueeze(0) for f in frames] + + vis = torch.cat(frames, dim=0) + return vis + + def load_vid_old(self, vid_id): + # if vid_id == 'QQM8M': + # print('bla') + vid_dir_path = os.path.join(self.root_vis, vid_id) + frame_paths = [os.path.join(vid_dir_path, f) for f in os.listdir(vid_dir_path)] + frame_paths.sort() + num_avail_frames = len(frame_paths) + delta = int(num_avail_frames / (self.config['num_frames'] - 1)) + ran = list(range(0, num_avail_frames, delta)) + if len(ran) < self.config['num_frames']: + ran.extend([num_avail_frames - 1 for _ in range(self.config['num_frames'] - len(ran))]) + if len(ran) > self.config['num_frames']: + ran = ran[:self.config['num_frames']] + assert len(ran) == self.config['num_frames'], f"vid {vid_id} - loaded {len(ran)}/{len(frame_paths)} frames" + frame_paths = [frame_paths[i] for i in ran] + vis = [Image.open(p).convert('RGB') for p in frame_paths] + vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis] + vis = torch.cat(vis, dim=0) + vis = self.trans(vis) + return vis + + def __getitem__(self, index): + dialog = self.dialogs[index] + vid_id = dialog['vid'] + + caption = dialog['caption'] + summary = dialog['summary'] + history = dialog['history'] + answer = dialog['answer'] + + caption = self.text_processor(caption) + summary = self.text_processor(summary) + if self.config.dstc != 10: + caption = caption + ' ' + summary + + history = [self.text_processor(h) for h in history] + answer = self.text_processor(answer, remove_period=True) + + if self.config.embed_from_llm: + if self.config.llm_family in ['llama', 'mistral']: + cls_tok = '' + sep_tok = ' ' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '' + sep_tok = '' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '[CLS]' + sep_tok = '[SEP]' + bos_tok = '[SEP]' + eos_tok = '[SEP]' + + caption = cls_tok + caption + sep_tok + history = sep_tok.join(history) + history = history + sep_tok + + # load the video frames + vis = self.load_vid(vid_id) + + return vis, caption, history, answer, vid_id + + +def load_avsd_dataset(config, vis_processor, text_processor, split): + # data_file = config['anno_avsd_{}'.format(split)] + # dataset_list = get_dataset(config, split, tokenizer_enc_dec) + dataset = AVSDDataSet(config, 'avsd', vis_processor, text_processor, split) + return dataset diff --git a/datasets/champagne_dataset.py b/datasets/champagne_dataset.py new file mode 100644 index 0000000..01d6853 --- /dev/null +++ b/datasets/champagne_dataset.py @@ -0,0 +1,279 @@ +# coding: utf-8 +# author: noctli +import json +import os +import pickle +import logging +from tqdm import tqdm +import numpy as np +import torch +import torch.utils.data +from PIL import Image +from torch.utils.data import Dataset +from itertools import chain +from torchvision import transforms +from .utils import type_transform_helper + + +def tokenize(text, tokenizer, return_tensor=False): + tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + if return_tensor: + return torch.tensor(tokenized_text).long() + return tokenized_text + + +def get_dataset(config, split): + + dialog_pth = config['anno_visdial_{}'.format(split)] + dialog_data = json.load(open(dialog_pth, 'r'))['data'] + all_answers = dialog_data['answers'] + all_questions = dialog_data['questions'] + dialog_list = [] + n_history = config['num_hist_turns'] + vid_set = set() + + pbar = tqdm(dialog_data['dialogs']) + pbar.set_description('[INFO] Loading VisDial - {}'.format(split)) + for dialog in pbar: + caption = dialog['caption'] + questions = [all_questions[d['question']] for d in dialog['dialog']] + answers = [all_answers[d['answer']] for d in dialog['dialog']] + + vid = dialog["image_id"] + vid_set.add(vid) + # if undisclosed_only: + # it = range(len(questions) - 1, len(questions)) + # else: + it = range(len(questions)) + qalist=[] + history = [] + # if undisclosed_only: + # for n in range(len(questions)-1): + # qalist.append(questions[n]) + # qalist.append(answers[n]) + # history=qalist[max(-len(qalist),-n_history*2):] + + for n in it: + # if undisclosed_only: + # assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__' + question = questions[n] + answer = answers[n] + history.append(question) + if n_history == 0: + item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption} + else: + item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption} + dialog_list.append(item) + qalist.append(question) + qalist.append(answer) + history=qalist[max(-len(qalist),-n_history*2):] + return dialog_list + + +class Champagne(Dataset): + def __init__(self, config, medium, vis_processor, text_processor, split): + + self.config = config + self.medium = medium + self.vis_processor = vis_processor + self.text_processor = text_processor + self.split = split + self.batch_size = config['batch_size_{}'.format(medium)] + + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + + # get the mapping between caption and image/video + mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None) + with open(mapping_path, 'rb') as f: + self.mapping = pickle.load(f) + + ids = list(self.mapping.keys()) + ids.sort() + + # reserve some samples for validation + if split == 'train': + self.ids = ids[config.num_val_samples:] + elif split == 'val': + self.ids = ids[:config.num_val_samples] + + num_samples = config['num_samples_{}'.format(self.medium)] + if num_samples > 0: + self.ids = self.ids[:num_samples] + + def __len__(self): + return len(self.ids) + + + def padding(self, seq, pad_token, max_len=None): + if max_len is None: + max_len = max([i.size(0) for i in seq]) + if len(seq[0].size()) == 1: + result = torch.ones((len(seq), max_len)).long() * pad_token + else: + result = torch.ones((len(seq), max_len, seq[0].size(-1))).float() + for i in range(len(seq)): + result[i, :seq[i].size(0)] = seq[i] + orig_len = [s.size(0) for s in seq] + return result, orig_len + + def __getitem__(self, index): + item = self.mapping[self.ids[index]] + # load the videos + pth = os.path.join(self.root_vis, item['path']) + f_names = os.listdir(pth) + if len(f_names) == 0: + with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f: + item = pickle.load(f) + + # load the videos + pth = os.path.join(self.root_vis, item['path']) + f_names = os.listdir(pth) + f_names.sort() + + if len(f_names) < self.config['num_frames']: + f_names += [f_names[-1]] * (self.config['num_frames'] - len(f_names)) + elif len(f_names) > self.config['num_frames']: + f_names = f_names[:self.config['num_frames']] + + pth = [os.path.join(pth, f_name) for f_name in f_names] + try: + vis = [Image.open(p).convert('RGB') for p in pth] + except: + with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f: + item = pickle.load(f) + + # load the videos + pth = os.path.join(self.root_vis, item['path']) + f_names = os.listdir(pth) + f_names.sort() + + pth = [os.path.join(pth, f_name) for f_name in f_names] + vis = [Image.open(p).convert('RGB') for p in pth] + + vis = [self.vis_processor(v).unsqueeze(0) for v in vis] + vis = torch.cat(vis, dim=0) + + dialog = item['dialog'] + + caption = dialog['caption'] + history = dialog['history'] + answer = dialog['answer'] + + caption = self.text_processor(caption) + history = [self.text_processor(h) for h in history] + answer = self.text_processor(answer, remove_period=True) + + if self.config.embed_from_llm: + if self.config.llm_family in ['llama', 'mistral']: + cls_tok = '' + sep_tok = '' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '' + sep_tok = '' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '[CLS]' + sep_tok = '[SEP]' + bos_tok = '[SEP]' + eos_tok = '[SEP]' + + # preprocess the textual data + caption = cls_tok + caption + sep_tok + history = sep_tok.join(history) + history = history + sep_tok + # if self.config.llm_family == 'flan_t5': + # answer = ' ' + self.text_processor(answer) + ' ' + # else: + # answer = self.text_processor(answer) + eos_tok + + return vis, caption, history, answer + + + # def collate_fn(self, batch): + + # BOS, EOS, SEP = self.tokenizer_enc_dec.convert_tokens_to_ids(['', '', '']) + + # vis_list, cap_list, hist_list, ques_list, ans_list, index_list, vid_id_list = [], [], [], [], [], [], [] + # batch_size = len(batch) + # for b in batch: + # vis_list.append(b[0]) + # cap = [BOS] + tokenize(b[1], self.tokenizer_enc_dec) + [EOS] + # cap_list.append(torch.tensor(cap)) + # if len(b[2])!=0: + # hist = [[SEP] + tokenize(s, self.tokenizer_enc_dec) for s in b[2]] + [[EOS]] + # hist_list.append(torch.tensor(list(chain(*hist)))) + # else: + # hist = [SEP] + tokenize(b[3], self.tokenizer_enc_dec) + [EOS] + # hist_list.append(torch.tensor(hist)) + + # ques = tokenize(b[3], self.tokenizer_enc_dec) + [EOS] + # ques_list.append(torch.tensor(ques)) + # ans = tokenize(b[4], self.tokenizer_enc_dec) + [EOS] + # ans_list.append(torch.tensor(ans)) + # index_list.append(b[5]) + # vid_id_list.append(b[6]) + + # # pad and keep track of the original lengths + # cap_input_ids, cap_orig_lens = self.padding(cap_list, self.tokenizer_experts.pad_token_id) + # hist_input_ids, hist_orig_lens = self.padding(hist_list, self.tokenizer_experts.pad_token_id) + # ques_input_ids, ques_orig_lens = self.padding(ques_list, self.tokenizer_experts.pad_token_id) + # ans_input_ids, _ = self.padding(ans_list, -100) + + # cap_attention_mask = cap_input_ids != self.tokenizer_experts.pad_token_id + # hist_attention_mask = hist_input_ids != self.tokenizer_experts.pad_token_id + # ques_attention_mask = ques_input_ids != self.tokenizer_experts.pad_token_id + + # total_orig_lens = [sum(l) for l in zip(cap_orig_lens, hist_orig_lens, ques_orig_lens)] + # max_len = max(total_orig_lens) + + # dummy_input_ids_enc_dec = torch.full((batch_size, max_len), self.tokenizer_experts.pad_token_id) + # enc_dec_attention_mask = torch.zeros_like(dummy_input_ids_enc_dec, dtype=torch.bool) + # for i, l in enumerate(total_orig_lens): + # enc_dec_attention_mask[i][:l] = True + # # add the masking of the visual input + # num_query_tok = self.config['num_temporal_query_tokens_{}'.format(self.config['bert_size'])] + # if self.medium in ['avsd', 'msrvtt', 'webvid', 'champagne']: + # vis_attention_mask = torch.ones((batch_size, 2 * num_query_tok), dtype=torch.bool) # *2 for spatial and temporal queries + # else: + # vis_attention_mask = torch.ones((batch_size, num_query_tok), dtype=torch.bool) # only spatial queries + + # enc_dec_attention_mask = torch.concat((vis_attention_mask, enc_dec_attention_mask), dim=1) + # # Now prepare the data + # vis = torch.stack(vis_list, dim=0) + # cap = { + # 'input_ids': cap_input_ids, + # 'attention_mask': cap_attention_mask, + # 'orig_lens': cap_orig_lens + # } + + # hist = { + # 'input_ids': hist_input_ids, + # 'attention_mask': hist_attention_mask, + # 'orig_lens': hist_orig_lens + # } + + # ques = { + # 'input_ids': ques_input_ids, + # 'attention_mask': ques_attention_mask, + # 'orig_lens': ques_orig_lens + # } + + # ans = { + # 'input_ids': ans_input_ids, + # } + + # enc_dec_input = { + # 'input_ids': dummy_input_ids_enc_dec, + # 'attention_mask': enc_dec_attention_mask, + # } + + # index = torch.tensor(index_list) + # return vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list + + +def load_champagne_dataset(config, vis_processor, text_processor, split): + dataset = Champagne(config, 'champagne', vis_processor, text_processor, split) + return dataset \ No newline at end of file diff --git a/datasets/dataloader.py b/datasets/dataloader.py new file mode 100644 index 0000000..0b72870 --- /dev/null +++ b/datasets/dataloader.py @@ -0,0 +1,137 @@ +""" +From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py +""" + +import torch +from torch.utils.data import DataLoader, Dataset, ConcatDataset +import torch.distributed as dist +from utils.dist import * +import random +import logging + +logger = logging.getLogger(__name__) + + +class MetaLoader(object): + """ wraps multiple data loader """ + def __init__(self, name2loader): + """Iterates over multiple dataloaders, it ensures all processes + work on data from the same dataloader. This loader will end when + the shorter dataloader raises StopIteration exception. + + loaders: Dict, {name: dataloader} + """ + self.name2loader = name2loader + self.name2iter = {name: iter(l) for name, l in name2loader.items()} + name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} + index2name = {v: k for k, v in name2index.items()} + + iter_order = [] + for n, l in name2loader.items(): + iter_order.extend([name2index[n]]*len(l)) + + random.shuffle(iter_order) + iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) + + # sync + if is_dist_avail_and_initialized(): + # make sure all processes have the same order so that + # each step they will have data from the same loader + dist.broadcast(iter_order, src=0) + self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] + + logger.info(str(self)) + + def __str__(self): + output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] + for idx, (name, loader) in enumerate(self.name2loader.items()): + output.append( + f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " + ) + return "\n".join(output) + + def __len__(self): + return len(self.iter_order) + + def __iter__(self): + """ this iterator will run indefinitely """ + for name in self.iter_order: + _iter = self.name2iter[name] + batch = next(_iter) + yield name, batch + + +def load_dataloaders(config, datasets, split, output_dict=False): + if isinstance(datasets, dict): + datasets = list(datasets.values()) + shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets) + if config['distributed'] and split != 'test': + num_tasks = get_world_size() + global_rank = get_rank() + samplers = create_samplers( + datasets, shuffles, num_tasks, global_rank + ) + else: + samplers = [None] * len(datasets) + + batch_size = [dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size for dataset in datasets] + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, ConcatDataset): + collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None)) + else: + collate_fns.append(getattr(dataset, 'collate_fn', None)) + + loaders = create_loader( + datasets, + samplers, + batch_size=batch_size, + num_workers=[config.num_workers] * len(datasets), + is_trains=shuffles, + collate_fns=collate_fns, + ) # [0] + loaders_dict = {} + if output_dict: + for l in loaders: + if isinstance(l.dataset, ConcatDataset): + loaders_dict[l.dataset.datasets[0].medium] = l + else: + loaders_dict[l.dataset.medium] = l + return loaders_dict + return loaders + + +def create_samplers(datasets, shuffles, num_tasks, global_rank): + samplers = [] + for dataset, shuffle in zip(datasets, shuffles): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle + ) + samplers.append(sampler) + return samplers + + +def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): + loaders = [] + for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( + datasets, samplers, batch_size, num_workers, is_trains, collate_fns + ): + if is_train: + shuffle = sampler is None + drop_last = True + else: + shuffle = False + drop_last = True + loader = DataLoader( + dataset, + batch_size=bs, + num_workers=n_worker, + pin_memory=False, + sampler=sampler, + shuffle=shuffle, + collate_fn=collate_fn, + drop_last=drop_last, + persistent_workers=True if n_worker > 0 else False, + ) + loaders.append(loader) + return loaders \ No newline at end of file diff --git a/datasets/nextqa_dataset.py b/datasets/nextqa_dataset.py new file mode 100644 index 0000000..50d5e0a --- /dev/null +++ b/datasets/nextqa_dataset.py @@ -0,0 +1,86 @@ +import os +import pandas as pd +# import h5py +import json +import numpy as np +import torch +from torch.utils.data import Dataset +from .video_utils import read_frames_decord + + +def load_file(file_name): + annos = None + if os.path.splitext(file_name)[-1] == '.csv': + return pd.read_csv(file_name) + with open(file_name, 'r') as fp: + if os.path.splitext(file_name)[1]== '.txt': + annos = fp.readlines() + annos = [line.rstrip() for line in annos] + if os.path.splitext(file_name)[1] == '.json': + annos = json.load(fp) + return annos + + +class NextQADataset(Dataset): + def __init__(self, config, medium, vis_processor, text_processor, split): + + super().__init__() + self.config = config + self.medium = medium + self.vis_processor = vis_processor + self.text_processor = text_processor + self.split = split + + self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)] + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + with open(config['vid_mapping_nextqa'], 'r') as f: + self.video_mapping = json.load(f) + + self.sample_list = load_file(self.config['anno_nextqa_{}'.format(split)]) + + if split == 'test': + self.sample_list = self.sample_list[config['start_idx_gen']: config['end_idx_gen']] + self.captions = load_file(self.config['next_qa_captions_{}'.format(split)]) + else: + self.captions = None + + num_samples = config['num_samples_{}'.format(self.medium)] + if num_samples > 0: + self.sample_list = self.sample_list[:num_samples] + + def __len__(self): + return len(self.sample_list) + + + def load_vid(self, vid_id): + vid_dir_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4') + + frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames) + frames = [self.vis_processor(f).unsqueeze(0) for f in frames] + + vis = torch.cat(frames, dim=0) + return vis + + def __getitem__(self, idx): + if self.split == 'test': + idx += self.config['start_idx_gen'] + + cur_sample = self.sample_list.loc[idx] + video_id, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\ + str(cur_sample['answer']), str(cur_sample['qid']) + + history = self.text_processor(ques) + answer = self.text_processor(ans) + if self.split == 'test': + caption = self.text_processor(self.captions[video_id]) + else: + caption = self.text_processor('please answer the following question based on the video') + vis = self.load_vid(video_id) + + return vis, caption, history, answer, video_id, qid + +def load_nextqa_dataset(config, vis_processor, text_processor, split): + # data_file = config['anno_avsd_{}'.format(split)] + # dataset_list = get_dataset(config, split, tokenizer_enc_dec) + dataset = NextQADataset(config, 'nextqa', vis_processor, text_processor, split) + return dataset diff --git a/datasets/pretraining.py b/datasets/pretraining.py new file mode 100644 index 0000000..13d6876 --- /dev/null +++ b/datasets/pretraining.py @@ -0,0 +1,156 @@ +from torch.utils.data import Dataset +import pickle +import os + +import torch +from PIL import Image +import numpy as np +from torchvision import transforms +import random + +from .utils import pre_text, type_transform_helper, load_anno, open_img + +class CapDataset(Dataset): + def __init__(self, config, medium, vis_processor, text_processor, split): + super(CapDataset, self).__init__() + self.config = config + self.batch_size = config['batch_size_{}'.format(medium)] + self.medium = medium # "webvid / cc3m / msrvtt" + self.vis_processor = vis_processor + self.text_processor = text_processor + self.split = split # train / val / test + + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + + # get the mapping between caption and image/video + mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None) + with open(mapping_path, 'rb') as f: + self.mapping = pickle.load(f) + + # These are the main ids of the dataset (typically one pro image/vid) + self.ids = list(self.mapping.keys()) + num_samples = config['num_samples_{}'.format(self.medium)] + if num_samples > 0: + self.ids = self.ids[:num_samples] + + def __len__(self): + return len(self.ids) + + def __getitem__(self, index): + item = self.mapping[self.ids[index]] + # _id = self.ids[index] + ############################# Textal features ############################# + caption = item['caption'] + # caption_ = pre_text(caption) + caption = self.text_processor(caption) + # add [CLS] token + caption = '[CLS] ' + caption + + if self.medium == 'cc3m': + pth = os.path.join(self.root_vis, item['file']) + vis = open_img(pth) + vis = self.vis_processor(vis).unsqueeze(0) + else: + pth = os.path.join(self.root_vis, item['file']) + f_names = os.listdir(pth) + f_names.sort(key=lambda f_n: int(f_n.split('.')[0])) + pth = [os.path.join(pth, f_name) for f_name in f_names] + vis = [Image.open(p).convert('RGB') for p in pth] + vis = [self.vis_processor(v).unsqueeze(0) for v in vis] + vis = torch.cat(vis, dim=0) + + # Get negative vis + neg_index = random.randint(0, len(self) - 1) + while neg_index == index: + neg_index = random.randint(0, len(self) - 1) + + neg_item = self.mapping[self.ids[neg_index]] + + if self.medium == 'cc3m': + neg_pth = os.path.join(self.root_vis, neg_item['file']) + neg_vis = open_img(neg_pth) + neg_vis = self.vis_processor(neg_vis).unsqueeze(0) + else: + neg_pth = os.path.join(self.root_vis, neg_item['file']) + neg_f_names = os.listdir(neg_pth) + neg_f_names.sort(key=lambda f_n: int(f_n.split('.')[0])) + neg_pth = [os.path.join(neg_pth, neg_f_name) for neg_f_name in neg_f_names] + neg_vis = [Image.open(p).convert('RGB') for p in neg_pth] + neg_vis = [self.vis_processor(v).unsqueeze(0) for v in neg_vis] + neg_vis = torch.cat(neg_vis, dim=0) + + # return caption, vis + return vis, caption, neg_vis + + +class VideoTextRetDataset(Dataset): + def __init__(self, config, vis_processor, text_processor, medium, split): + super(VideoTextRetDataset, self).__init__() + + self.config = config + self.batch_size = config['batch_size_{}'.format(medium)] + self.medium = medium # "webvid / cc3m / msrvtt" + self.vis_processor = vis_processor + self.text_processor = text_processor + self.split = split # train / val / test + + + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + + anno_path = config['annotation_{}_{}'.format(medium, split)] + self.raw_anno_list = load_anno(anno_path) + self.text = [] + self.vis = [] + self.txt2vis = {} + self.vis2txt = {} + self.build_data() + self.anno_list = [dict(vis=v) for v in self.vis] + # print('bla') + + def __len__(self): + return len(self.anno_list) + + def __getitem__(self, index): + pth = self.anno_list[index]['vis'] + f_names = os.listdir(pth) + f_names.sort(key=lambda f_n: int(f_n.split('.')[0])) + pth = [os.path.join(pth, f_name) for f_name in f_names] + vis = [Image.open(p).convert('RGB') for p in pth] + vis = [self.vis_processor(v) for v in vis] + # vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis] + vis = torch.cat(vis, dim=0) + # vis = self.trans(vis) + + return vis, index + + def build_data(self): + """each image may have multiple ground_truth text, e.g., COCO and Flickr30K""" + txt_id = 0 + for vis_id, ann in enumerate(self.raw_anno_list): + self.vis.append(ann["vis"]) + self.vis2txt[vis_id] = [] + _captions = ann["caption"] \ + if isinstance(ann["caption"], list) else [ann["caption"], ] + for i, caption in enumerate(_captions): + # self.text.append(pre_text(caption)) + self.text.append(self.text_processor(caption)) + self.vis2txt[vis_id].append(txt_id) + self.txt2vis[txt_id] = vis_id + txt_id += 1 + + +def load_datasets(config, vis_processor, text_processor, split): + if config['stage'] == 'stage_1': + if split != 'test': + cc3m_dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, split) + webvid_dataset = CapDataset(config, 'webvid', vis_processor, text_processor, split) + datasets = { + 'cc3m': cc3m_dataset, + 'webvid': webvid_dataset + } + else: # Test with msrvtt_1k --> video retieval + msrvtt_dataset = VideoTextRetDataset(config, vis_processor, text_processor, 'msrvtt', split) + datasets = { + 'msrvtt': msrvtt_dataset + } + return datasets \ No newline at end of file diff --git a/datasets/utils.py b/datasets/utils.py new file mode 100644 index 0000000..0ef9df2 --- /dev/null +++ b/datasets/utils.py @@ -0,0 +1,83 @@ +import os +import re +import json +from tqdm import trange +from utils.dist import is_main_process +from torch.utils.data import Dataset, ConcatDataset +from PIL import Image +import numpy as np + +def open_img(img_pth): + try: + img = Image.open(img_pth).convert('RGB') + return img + except: + img = np.random.randint(0, high=256, size=(224,224, 3)) + img = Image.fromarray(img, 'RGB') + return img + + +def pre_text(text, max_l=None): + text = re.sub(r"(['!?\"()*#:;~])", '', text.lower()) + text = text.replace('-', ' ').replace('/', ' ').replace('', 'person') + + text = re.sub(r"\s{2,}", ' ', text) + text = text.rstrip('\n').strip(' ') + + if max_l: # truncate + words = text.split(' ') + if len(words) > max_l: + text = ' '.join(words[:max_l]) + return text + + +def get_datasets_media(dataloaders): + media = {} + for dataloader in dataloaders: + if isinstance(dataloader.dataset, ConcatDataset): + media[dataloader.dataset.datasets[0].medium] = dataloader + else: + media[dataloader.dataset.medium] = dataloader + + # media = [dataloader.dataset.medium for dataloader in dataloaders] + return media + +def type_transform_helper(x): + return x.float().div(255.0) + +def load_anno(ann_file_list): + """[summary] + + Args: + ann_file_list (List[List[str, str]] or List[str, str]): + the latter will be automatically converted to the former. + Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video']) + which specifies the data type, video or image + + Returns: + List(dict): each dict is { + image: str or List[str], # image_path, + caption: str or List[str] # caption text string + } + """ + if isinstance(ann_file_list[0], str): + ann_file_list = [ann_file_list] + + ann = [] + for d in ann_file_list: + data_root = d[1] + fp = d[0] + is_video = len(d) == 3 and d[2] == "video" + cur_ann = json.load(open(fp, "r")) + iterator = trange(len(cur_ann), desc=f"Loading {fp}") \ + if is_main_process() else range(len(cur_ann)) + for idx in iterator: + key = "video" if is_video else "image" + video_id = cur_ann[idx][key][5:].split('.')[0] + # unified to have the same key for data path + # if isinstance(cur_ann[idx][key], str): + cur_ann[idx]["vis"] = os.path.join(data_root, video_id) + # else: # list + # cur_ann[idx]["vis"] = [os.path.join(data_root, e) for e in cur_ann[idx][key]] + ann += cur_ann + return ann \ No newline at end of file diff --git a/datasets/video_utils.py b/datasets/video_utils.py new file mode 100644 index 0000000..8e5be71 --- /dev/null +++ b/datasets/video_utils.py @@ -0,0 +1,97 @@ +""" +Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py +""" +import random +import decord +from PIL import Image +import numpy as np +import math +decord.bridge.set_bridge("torch") + + +def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: + """ + Converts a present time with the given time base and start_pts offset to seconds. + + Returns: + time_in_seconds (float): The corresponding time in seconds. + + https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 + """ + if pts == math.inf: + return math.inf + + return int(pts - start_pts) * time_base + + +def get_pyav_video_duration(video_reader): + video_stream = video_reader.streams.video[0] + video_duration = pts_to_secs( + video_stream.duration, + video_stream.time_base, + video_stream.start_time + ) + return float(video_duration) + + +def get_frame_indices_by_fps(): + pass + + +def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): + if sample in ["rand", "middle"]: + acc_samples = min(num_frames, vlen) + # split the video into `acc_samples` intervals, and sample from each interval. + intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) + ranges = [] + for idx, interv in enumerate(intervals[:-1]): + ranges.append((interv, intervals[idx + 1] - 1)) + if sample == 'rand': + try: + frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] + except: + frame_indices = np.random.permutation(vlen)[:acc_samples] + frame_indices.sort() + frame_indices = list(frame_indices) + elif fix_start is not None: + frame_indices = [x[0] + fix_start for x in ranges] + elif sample == 'middle': + frame_indices = [(x[0] + x[1]) // 2 for x in ranges] + else: + raise NotImplementedError + + if len(frame_indices) < num_frames: # padded with last frame + padded_frame_indices = [frame_indices[-1]] * num_frames + padded_frame_indices[:len(frame_indices)] = frame_indices + frame_indices = padded_frame_indices + elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps + output_fps = float(sample[3:]) + duration = float(vlen) / input_fps + delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents + frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) + frame_indices = np.around(frame_seconds * input_fps).astype(int) + frame_indices = [e for e in frame_indices if e < vlen] + if max_num_frames > 0 and len(frame_indices) > max_num_frames: + frame_indices = frame_indices[:max_num_frames] + # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) + else: + raise ValueError + return frame_indices + + +def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1): + video_reader = decord.VideoReader(video_path, num_threads=1) + vlen = len(video_reader) + fps = video_reader.get_avg_fps() + duration = vlen / float(fps) + frame_indices = get_frame_indices( + num_frames, vlen, sample=sample, fix_start=fix_start, + input_fps=fps, max_num_frames=max_num_frames + ) + frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 + frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 + frames = frames.split(1, dim=0) + + frames = [Image.fromarray(f.squeeze().numpy(), mode='RGB') for f in frames] + # frames = frames.numpy() # convert to numpy + return frames, frame_indices, duration diff --git a/datasets/visdial_dataset.py b/datasets/visdial_dataset.py new file mode 100644 index 0000000..1382270 --- /dev/null +++ b/datasets/visdial_dataset.py @@ -0,0 +1,183 @@ +# coding: utf-8 +# author: noctli +import json +import os +import pickle +import logging +from tqdm import tqdm +import numpy as np +import torch +import torch.utils.data +from PIL import Image +from torch.utils.data import Dataset +from itertools import chain +from torchvision import transforms +from .utils import type_transform_helper +from .utils import open_img + +def tokenize(text, tokenizer, return_tensor=False): + tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + if return_tensor: + return torch.tensor(tokenized_text).long() + return tokenized_text + + +def get_dataset(config, split): + + dialog_pth = config['anno_visdial_{}'.format(split)] + dialog_data = json.load(open(dialog_pth, 'r'))['data'] + + all_answers = dialog_data['answers'] + all_questions = dialog_data['questions'] + dialog_list = [] + n_history = config['num_hist_turns_visdial'] + vid_set = set() + undisclosed_only = False + + pbar = tqdm(dialog_data['dialogs']) + pbar.set_description('[INFO] Loading VisDial - {}'.format(split)) + for dialog in pbar: + caption = dialog['caption'] + ' .' + questions = [all_questions[d['question']] + ' ?' for d in dialog['dialog']] + answers = [all_answers[d['answer']] + ' .' for d in dialog['dialog']] + + # answer_opts = [[all_answers[key] for key in d['answer_options']] for d in dialog['dialog']] + # if 'test' in config['anno_visdial_{}'.format(split)]: + # gt_indices = [-1 for _ in range(len(questions))] + # else: + # gt_indices = [d['gt_index'] for d in dialog['dialog']] + + vid = dialog['image_id'] + vid_set.add(vid) + if undisclosed_only: + it = range(len(questions) - 1, len(questions)) + else: + it = range(len(questions)) + + qalist=[] + history = [] + if undisclosed_only: + for n in range(len(questions)-1): + qalist.append(questions[n]) + qalist.append(answers[n]) + history=qalist[max(-len(qalist),-n_history*2):] + + for n in it: + if undisclosed_only: + assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__' + question = questions[n] + answer = answers[n] + # answer_opt = answer_opts[n] + # gt_index = gt_indices[n] + history.append(question) + # if n_history == 0: + # item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index} + # else: + # item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index} + + if n_history == 0: + item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n+1} + else: + item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n+1} + + + + dialog_list.append(item) + qalist.append(question) + qalist.append(answer) + history=qalist[max(-len(qalist),-n_history*2):] + + return dialog_list + + +class VisDial(Dataset): + def __init__(self, config, medium, vis_processor, text_processor, split + # tokenizer, features=None, drop_rate=0.0, train=True + ): + self.config = config + self.medium = medium + self.split = split + self.vis_processor = vis_processor + self.text_processor = text_processor + self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)] + self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)] + + self.dialogs = get_dataset(config, split) + + if split == 'test': + self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']] + + num_samples = config['num_samples_{}'.format(self.medium)] + if num_samples > 0: + self.dialogs = self.dialogs[:num_samples] + + def __len__(self): + return len(self.dialogs) + + def load_img(self, vid_id): + file_pth = os.path.join(self.root_vis, f'{vid_id}.jpg') + vis = open_img(file_pth) + vis = self.vis_processor(vis).unsqueeze(0) + return vis + + def __getitem__(self, index): + dialog = self.dialogs[index] + + vid_id = dialog['vid'] + caption = dialog['caption'] + history = dialog['history'] + answer = dialog['answer'] + d_round = dialog['round'] + + caption = self.text_processor(caption) + history = [self.text_processor(h) for h in history] + answer = self.text_processor(answer, remove_period=True) + + + # if self.split == 'test': + # answer_opts = dialog['answer_opts'] + # answer_opts = [self.text_processor(a) for a in answer_opts] + + # gt_index = dialog['gt_index'] + # dialog_round = dialog['round'] + + # dense_key = str(vid_id) + '_' + str(dialog_round) + # gt_relevance = self.dense_annos.get(dense_key, -1) + # # eval_data = (answer_opts, gt_index, gt_relevance) + + + if self.config.embed_from_llm: + if self.config.llm_family in ['llama', 'mistral']: + cls_tok = '' + sep_tok = ' ' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '' + sep_tok = '' + bos_tok = '' + eos_tok = '' + else: + cls_tok = '[CLS]' + sep_tok = '[SEP]' + bos_tok = '[SEP]' + eos_tok = '[SEP]' + + caption = cls_tok + caption + sep_tok + history = sep_tok.join(history) + history = history + sep_tok + + # load the video frames + vis = self.load_img(vid_id) + + # if self.split == 'test': + # return vis, caption, history, answer, vid_id, answer_opts, gt_relevance, gt_index + + # else: + return vis, caption, history, answer, vid_id, d_round + + + +def load_visdial_dataset(config, vis_processor, text_processor, split): + dataset = VisDial(config, 'visdial', vis_processor, text_processor, split) + return dataset diff --git a/emergency/item.pkl b/emergency/item.pkl new file mode 100644 index 0000000..9289657 Binary files /dev/null and b/emergency/item.pkl differ diff --git a/eval_visdial.py b/eval_visdial.py new file mode 100644 index 0000000..e73684f --- /dev/null +++ b/eval_visdial.py @@ -0,0 +1,81 @@ +import os +import torch +import json +from utils.metrcis import SparseGTMetrics, NDCG +from Levenshtein import ratio + + +output_dir = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial' + +file_paths = os.listdir(output_dir) +file_paths = list(filter(lambda f: 'part' in f , file_paths)) +name = file_paths[0] +file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths)) + +results = {} +count = 0 +for pth in file_paths: + with open(pth, 'r') as f: + partial_res = json.load(f) + count += len(partial_res) + results.update(partial_res) + # dialogs.extend(data['dialogs']) + os.remove(pth) + + +name = "".join(name.split('-')[:-1]) + '.json' +output_path = os.path.join(output_dir, name) +with open(output_path, 'w') as f: + json.dump(results, f, indent=4) + + +# result_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial/zeroshot_visdial_after_champagne_googleflant5large_results_dstc8_beam_depth_8_lenPen_0.3.json' + +# with open(result_path, 'r') as f: +# results = json.load(f) + +annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json' +dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json' + +with open(annos_path, 'r') as f: + data = json.load(f)['data'] + +all_answers = data['answers'] +all_questions = data['questions'] + + +dialogs = data['dialogs'] + +dialogs_dict = {} + +for dialog in dialogs: + image_id = dialog['image_id'] + for i, turn in enumerate(dialog['dialog']): + answer_opts = [all_answers[a] for a in turn['answer_options']] + dialogs_dict[str(image_id) + '_' + str(i+1)] = { + 'answer_opts': answer_opts, + 'gt_index': turn['gt_index'] + } + # print('bla') + +with open(dense_annos_path, 'r') as f: + dense_data = json.load(f) + +dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data} + +sparse_metrics = SparseGTMetrics() +ndcg = NDCG() + +for res_key, res in results.items(): + answer_opts = dialogs_dict[res_key]['answer_opts'] + gt_index = torch.tensor(dialogs_dict[res_key]['gt_index']) + + scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0) + sparse_metrics.observe(scores, gt_index) + if res_key in dense_data: + gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0) + ndcg.observe(scores.squeeze(0), gt_relevance) + # print('bla') + +print(sparse_metrics.retrieve()) +print(ndcg.retrieve()) diff --git a/eval_visdial_sentence_embeddings.py b/eval_visdial_sentence_embeddings.py new file mode 100644 index 0000000..527e160 --- /dev/null +++ b/eval_visdial_sentence_embeddings.py @@ -0,0 +1,273 @@ +from sentence_transformers.cross_encoder import CrossEncoder +import os +import torch +import json +import numpy as np + +def scores_to_ranks(scores: torch.Tensor): + """Convert model output scores into ranks.""" + batch_size, num_rounds, num_options = scores.size() + scores = scores.view(-1, num_options) + + # sort in descending order - largest score gets highest rank + sorted_ranks, ranked_idx = scores.sort(1, descending=True) + + # i-th position in ranked_idx specifies which score shall take this + # position but we want i-th position to have rank of score at that + # position, do this conversion + ranks = ranked_idx.clone().fill_(0) + for i in range(ranked_idx.size(0)): + for j in range(num_options): + ranks[i][ranked_idx[i][j]] = j + # convert from 0-99 ranks to 1-100 ranks + ranks += 1 + ranks = ranks.view(batch_size, num_rounds, num_options) + return ranks + + +class SparseGTMetrics(object): + """ + A class to accumulate all metrics with sparse ground truth annotations. + These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. + """ + + def __init__(self): + self._rank_list = [] + + def observe( + self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor + ): + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, num_rounds, num_options) + predicted_ranks = scores_to_ranks(predicted_scores) + batch_size, num_rounds, num_options = predicted_ranks.size() + + # collapse batch dimension + predicted_ranks = predicted_ranks.view( + batch_size * num_rounds, num_options + ) + + # shape: (batch_size * num_rounds, ) + target_ranks = target_ranks.view(batch_size * num_rounds).long() + + # shape: (batch_size * num_rounds, ) + predicted_gt_ranks = predicted_ranks[ + torch.arange(batch_size * num_rounds), target_ranks + ] + self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) + + def retrieve(self, reset: bool = True): + num_examples = len(self._rank_list) + if num_examples > 0: + # convert to numpy array for easy calculation. + __rank_list = torch.tensor(self._rank_list).float() + metrics = { + "r@1": torch.mean((__rank_list <= 1).float()).item(), + "r@5": torch.mean((__rank_list <= 5).float()).item(), + "r@10": torch.mean((__rank_list <= 10).float()).item(), + "mean": torch.mean(__rank_list).item(), + "mrr": torch.mean(__rank_list.reciprocal()).item(), + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._rank_list = [] + + +class NDCG(object): + def __init__(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + + def observe( + self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor + ): + """ + Observe model output scores and target ground truth relevance and + accumulate NDCG metric. + + Parameters + ---------- + predicted_scores: torch.Tensor + A tensor of shape (batch_size, num_options), because dense + annotations are available for 1 randomly picked round out of 10. + target_relevance: torch.Tensor + A tensor of shape same as predicted scores, indicating ground truth + relevance of each answer option for a particular round. + """ + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, 1, num_options) + predicted_scores = predicted_scores.unsqueeze(1) + predicted_ranks = scores_to_ranks(predicted_scores) + + # shape: (batch_size, num_options) + predicted_ranks = predicted_ranks.squeeze(1) + batch_size, num_options = predicted_ranks.size() + + k = torch.sum(target_relevance != 0, dim=-1) + + # shape: (batch_size, num_options) + _, rankings = torch.sort(predicted_ranks, dim=-1) + # Sort relevance in descending order so highest relevance gets top rnk. + _, best_rankings = torch.sort( + target_relevance, dim=-1, descending=True + ) + + # shape: (batch_size, ) + batch_ndcg = [] + for batch_index in range(batch_size): + + num_relevant = k[batch_index] + dcg = self._dcg( + rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + best_dcg = self._dcg( + best_rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + batch_ndcg.append(dcg / best_dcg) + + self._ndcg_denominator += batch_size + self._ndcg_numerator += sum(batch_ndcg) + + def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): + sorted_relevance = relevance[rankings].cpu().float() + discounts = torch.log2(torch.arange(len(rankings)).float() + 2) + return torch.sum(sorted_relevance / discounts, dim=-1) + + def retrieve(self, reset: bool = True): + if self._ndcg_denominator > 0: + metrics = { + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + + +annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json' +with open(annos_path, 'r') as f: + data = json.load(f)['data'] + +dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json' +with open(dense_annos_path, 'r') as f: + dense_data = json.load(f) + +dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data} + +results_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial_before_supplementary/zeroshot_visdial_after_avsd_4_frames_3_rounds_ft_fp16_googleflant5large_results_dstc10_beam_depth_4_lenPen_0.3.json' +with open(results_path, 'r') as f: + results = json.load(f) + +all_answers = data['answers'] +all_questions = data['questions'] + + +dialogs = data['dialogs'] + +dialogs_dict = {} + +for dialog in dialogs: + image_id = dialog['image_id'] + for i, turn in enumerate(dialog['dialog']): + answer_opts = [all_answers[a] for a in turn['answer_options']] + dialogs_dict[str(image_id) + '_' + str(i+1)] = { + 'answer_opts': answer_opts, + 'gt_index': turn['gt_index'] + } + # print('bla') + +sparse_metrics = SparseGTMetrics() +ndcg = NDCG() + +# 1. Load a pretrained CrossEncoder model +model = CrossEncoder("cross-encoder/stsb-roberta-large") + +for i, (res_key, res) in enumerate(results.items()): + print('[INFO] {} / {}'.format(i+1, len(results))) + answer_opts = dialogs_dict[res_key]['answer_opts'] + gt_index = torch.tensor(dialogs_dict[res_key]['gt_index']) + gt_answer = answer_opts[gt_index] + sentence_combinations = [[res, opt] for opt in answer_opts] + scores = model.predict(sentence_combinations) + scores = torch.from_numpy(scores).unsqueeze(0).unsqueeze(0) + # scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0) + # scores = model.rank(res, answer_opts) + ranked_idx = scores_to_ranks(scores).squeeze().tolist() + new_order = np.argsort(ranked_idx) + # ranked_answers = [answer_opts[idx] for idx in new_order] + best_pick = answer_opts[new_order[0]] + sparse_metrics.observe(scores, gt_index) + if res_key in dense_data: + gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0) + ndcg.observe(scores.squeeze(0), gt_relevance) + + # print('bla') +print(sparse_metrics.retrieve()) +print(ndcg.retrieve()) + +# We want to compute the similarity between the query sentence... +# query = "A man is eating pasta." + +# # ... and all sentences in the corpus +# corpus = [ +# "A man is eating food.", +# "A man is eating a piece of bread.", +# "The girl is carrying a baby.", +# "A man is riding a horse.", +# "A woman is playing violin.", +# "Two men pushed carts through the woods.", +# "A man is riding a white horse on an enclosed ground.", +# "A monkey is playing drums.", +# "A cheetah is running behind its prey.", +# ] + +# # 2. We rank all sentences in the corpus for the query +# ranks = model.rank(query, corpus) + +# # Print the scores +# print("Query: ", query) +# for rank in ranks: +# print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}") +# """ +# Query: A man is eating pasta. +# 0.67 A man is eating food. +# 0.34 A man is eating a piece of bread. +# 0.08 A man is riding a horse. +# 0.07 A man is riding a white horse on an enclosed ground. +# 0.01 The girl is carrying a baby. +# 0.01 Two men pushed carts through the woods. +# 0.01 A monkey is playing drums. +# 0.01 A woman is playing violin. +# 0.01 A cheetah is running behind its prey. +# """ + +# # 3. Alternatively, you can also manually compute the score between two sentences +# import numpy as np + +# sentence_combinations = [[query, sentence] for sentence in corpus] +# scores = model.predict(sentence_combinations) + +# # Sort the scores in decreasing order to get the corpus indices +# ranked_indices = np.argsort(scores)[::-1] +# print("Scores:", scores) +# print("Indices:", ranked_indices) +# """ +# Scores: [0.6732372, 0.34102544, 0.00542465, 0.07569341, 0.00525378, 0.00536814, 0.06676237, 0.00534825, 0.00516717] +# Indices: [0 1 3 6 2 5 7 4 8] +# """ \ No newline at end of file diff --git a/generate_parallel_avsd.sh b/generate_parallel_avsd.sh new file mode 100755 index 0000000..2c2b0be --- /dev/null +++ b/generate_parallel_avsd.sh @@ -0,0 +1,71 @@ +# export MODEL=$1 +# export TAG=$2 +# export MODE=$3 +# export EVAL_DIR=$4 +# export MEDIUM=$5 +# export DSTC=$6 + +export MODEL='v2dial/stage_3' +export TAG='finetuned_no_experts_avsd' +export MODE='generate' +export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetune_without_experts_avsd' +export DSTC=7 + +# >>> conda initialize >>> +# !! Contents within this block are managed by 'conda init' !! +__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" +if [ $? -eq 0 ]; then + eval "$__conda_setup" +else + if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then + . "/opt/anaconda3/etc/profile.d/conda.sh" + else + export PATH="/opt/anaconda3/bin:$PATH" + fi +fi +unset __conda_setup +# <<< conda initialize <<< + +conda activate v2dial + +if [ $DSTC -eq 10 ]; then + export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0112 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0112 --end_idx_gen 0224 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 0224 --end_idx_gen 0336 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 0336 --end_idx_gen 0448 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 0448 --end_idx_gen 0560 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 0560 --end_idx_gen 0672 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 0672 --end_idx_gen 0784 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 0784 --end_idx_gen 0896 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0896 --end_idx_gen 1008 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1008 --end_idx_gen 1120 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 1120 --end_idx_gen 1232 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 1232 --end_idx_gen 1344 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 1344 --end_idx_gen 1456 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 1456 --end_idx_gen 1568 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 1568 --end_idx_gen 1680 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 1680 --end_idx_gen 1804 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +else + export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0107 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0107 --end_idx_gen 0214 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 0214 --end_idx_gen 0321 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 0321 --end_idx_gen 0428 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 0428 --end_idx_gen 0535 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 0535 --end_idx_gen 0642 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 0642 --end_idx_gen 0749 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 0749 --end_idx_gen 0856 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0856 --end_idx_gen 0963 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 0963 --end_idx_gen 1070 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 1070 --end_idx_gen 1177 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 1177 --end_idx_gen 1284 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 1284 --end_idx_gen 1391 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 1391 --end_idx_gen 1498 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 1498 --end_idx_gen 1605 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 1605 --end_idx_gen 1710 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +fi + +# export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 00 --end_idx_gen 10 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +# export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 10 --end_idx_gen 20 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + +wait +python merge_pred_avsd.py --dstc $DSTC diff --git a/generate_parallel_nextqa.sh b/generate_parallel_nextqa.sh new file mode 100755 index 0000000..d90e318 --- /dev/null +++ b/generate_parallel_nextqa.sh @@ -0,0 +1,51 @@ +# export MODEL=$1 +# export TAG=$2 +# export MODE=$3 +# export EVAL_DIR=$4 + +export MODEL='v2dial/stage_3' +export TAG='nextqa_with_test_captions' +export MODE='generate' +export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-from_stage1_only_nextqa_after_avsd_4_frames_3_rounds_ft_fp16' + +# >>> conda initialize >>> +# !! Contents within this block are managed by 'conda init' !! +__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" +if [ $? -eq 0 ]; then + eval "$__conda_setup" +else + if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then + . "/opt/anaconda3/etc/profile.d/conda.sh" + else + export PATH="/opt/anaconda3/bin:$PATH" + fi +fi +unset __conda_setup +# <<< conda initialize <<< + +conda activate v2dial + +# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 10 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 10 --end_idx_gen 20 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + + +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0000 --end_idx_gen 0573 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 0573 --end_idx_gen 1146 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1146 --end_idx_gen 1719 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 1719 --end_idx_gen 2292 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 2292 --end_idx_gen 2865 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 2865 --end_idx_gen 3438 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 3438 --end_idx_gen 4011 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 4011 --end_idx_gen 4584 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 4584 --end_idx_gen 5157 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 5157 --end_idx_gen 5730 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 5730 --end_idx_gen 6303 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 6303 --end_idx_gen 6876 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 6876 --end_idx_gen 7449 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 7449 --end_idx_gen 8022 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 8022 --end_idx_gen 8495 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 8495 --end_idx_gen 9178 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + +wait + +python merge_pred_nextqa.py \ No newline at end of file diff --git a/generate_parallel_visdial.sh b/generate_parallel_visdial.sh new file mode 100755 index 0000000..e18e070 --- /dev/null +++ b/generate_parallel_visdial.sh @@ -0,0 +1,67 @@ +# export MODEL=$1 +# export TAG=$2 +# export MODE=$3 +# export EVAL_DIR=$4 +# export MEDIUM=$5 +# export DSTC=$6 + +export MODEL='v2dial/stage_3' +export TAG='finetuned_visdial_from_scratch' +export MODE='generate' +export EVAL_DIR='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetuned_visdial_from_scratch/' + +# >>> conda initialize >>> +# !! Contents within this block are managed by 'conda init' !! +__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" +if [ $? -eq 0 ]; then + eval "$__conda_setup" +else + if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then + . "/opt/anaconda3/etc/profile.d/conda.sh" + else + export PATH="/opt/anaconda3/bin:$PATH" + fi +fi +unset __conda_setup +# <<< conda initialize <<< + +conda activate v2dial +# export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 00000 --end_idx_gen 10 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +# export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 00010 --end_idx_gen 20 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + + +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 00000 --end_idx_gen 00645 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 00645 --end_idx_gen 01290 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 01290 --end_idx_gen 01935 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 01935 --end_idx_gen 02580 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 02580 --end_idx_gen 03225 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 03225 --end_idx_gen 03870 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 03870 --end_idx_gen 04515 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 04515 --end_idx_gen 05160 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 05160 --end_idx_gen 05805 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 05805 --end_idx_gen 06450 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 06450 --end_idx_gen 07095 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 07095 --end_idx_gen 07740 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 07740 --end_idx_gen 08385 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 08385 --end_idx_gen 09030 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 09030 --end_idx_gen 09675 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 09675 --end_idx_gen 10320 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 10320 --end_idx_gen 10965 --gen_subset_num 17 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 10965 --end_idx_gen 11610 --gen_subset_num 18 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 11610 --end_idx_gen 12255 --gen_subset_num 19 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 12255 --end_idx_gen 12900 --gen_subset_num 20 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 12900 --end_idx_gen 13545 --gen_subset_num 21 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 13545 --end_idx_gen 14190 --gen_subset_num 22 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 14190 --end_idx_gen 14835 --gen_subset_num 23 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 14835 --end_idx_gen 15480 --gen_subset_num 24 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=0; python main_stage_3.py --start_idx_gen 15480 --end_idx_gen 16125 --gen_subset_num 25 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=1; python main_stage_3.py --start_idx_gen 16125 --end_idx_gen 16770 --gen_subset_num 26 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=2; python main_stage_3.py --start_idx_gen 16770 --end_idx_gen 17415 --gen_subset_num 27 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=3; python main_stage_3.py --start_idx_gen 17415 --end_idx_gen 18060 --gen_subset_num 28 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=4; python main_stage_3.py --start_idx_gen 18060 --end_idx_gen 18705 --gen_subset_num 29 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=5; python main_stage_3.py --start_idx_gen 18705 --end_idx_gen 19350 --gen_subset_num 30 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=6; python main_stage_3.py --start_idx_gen 19350 --end_idx_gen 19995 --gen_subset_num 31 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ +export CUDA_VISIBLE_DEVICES=7; python main_stage_3.py --start_idx_gen 19995 --end_idx_gen 20640 --gen_subset_num 32 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \ + +wait +python eval_visdial.py diff --git a/main_stage_1.py b/main_stage_1.py new file mode 100644 index 0000000..292b236 --- /dev/null +++ b/main_stage_1.py @@ -0,0 +1,177 @@ + +import argparse + + +import torch + +import torch.multiprocessing as mp +import torch.nn as nn +import torch.distributed as dist +# from transformers import BartTokenizer +from torch.utils.data import ConcatDataset + +from utils.init import initialize_from_env +# from datasets.pretraining import load_datasets, VideoTextRetDataset +# from datasets.utils import get_datasets_media +from models.setup import setup_model, setup_data, setup_data_test +from tasks.pre_train import pre_train +# from tasks.ft_avsd import ft_avsd, generate +# from tasks.stage_2_3 import pretrain +# from tasks.stage_2 import train as train_stage_2 + +# torch.autograd.set_detect_anomaly(True) + +parser = argparse.ArgumentParser(description='Main script for v2dial') +parser.add_argument( + '--model', + type=str, + default='v2dial/stage_1', + help='model name to train or test') + +parser.add_argument( + '--mode', + type=str, + default='train', + help='train, generate or debug' + ) + +parser.add_argument( + '--eval_dir', + type=str, + default='/scratch/abdessaied/projects/V2Dial_TU/logs/stage_4/v2dial-flant5_large_bert_experts_4_only_gen_AVSD' +) + +parser.add_argument( + '--wandb_mode', + type=str, + default='online', + choices=['online', 'offline', 'disabled', 'run', 'dryrun'] +) + +parser.add_argument( + '--wandb_project', + type=str, + default='V2Dial' +) + +parser.add_argument( + '--tag', + type=str, + # default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD', + # default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed', + # default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16', + default='without_sep_spatial_temporal_experts', + # default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs', + help="Tag to differentiate the models" +) + +parser.add_argument( + '--medium', + type=str, + default='avsd', + help="Medium of the test dataset" +) + +parser.add_argument( + '--start_idx_gen', + type=int, + default=0, + help="The start index for generation" +) + +parser.add_argument( + '--end_idx_gen', + type=int, + default=10, + help="The end index for generation" +) + +parser.add_argument( + '--gen_subset_num', + type=int, + default=1, + help="The index of the test split for generation" +) + +parser.add_argument('--ssh', action='store_true', + help='whether or not we are executing command via ssh. ' + 'If set to True, we will not log.info anything to screen and only redirect them to log file') + + +def main(gpu, config, args): + + config['gpu'] = gpu + if config['distributed']: + dist.init_process_group( + backend='nccl', + world_size=config['num_gpus'], + rank=gpu + ) + torch.cuda.set_device(gpu) + + device = torch.device(f'cuda:{gpu}') + if config.use_cpu: + device = torch.device('cpu') + config['device'] = device + # model = V2Dial(config) + + # config['num_training_steps'] = num_step_per_epoch * config['epochs'] + # config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs'] + if config['training']: + train_dataloaders, val_dataloaders = setup_data(config) + + ( + model, + model_without_ddp, + optimizer, + scheduler, + scaler, + start_epoch, + global_step, + webvid_step, + cc3m_step, + config + ) = setup_model(config, pretrain=True) + + pre_train( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + webvid_step, + cc3m_step, + scheduler, + scaler, + start_epoch, + config + ) + + if config['distributed']: + dist.destroy_process_group() + +if __name__ == '__main__': + args = parser.parse_args() + + # initialization + model, stage = args.model.split('/') + config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag) + config['wandb_enabled'] = args.wandb_mode == 'online' + config['training'] = args.mode == 'train' + config['generating'] = args.mode == 'generate' + config['debugging'] = args.mode == 'debug' + + config['wandb_mode'] = args.wandb_mode + config['medium'] = args.medium + config['start_idx_gen'] = args.start_idx_gen + config['end_idx_gen'] = args.end_idx_gen + + # config['wandb_project'] + # if config['accelerator'] == 'ddp': + if config['num_gpus'] > 1: + config['distributed'] = True + mp.spawn(main, nprocs=config['num_gpus'], args=(config, args)) + else: + config['distributed'] = False + main(0, config, args) diff --git a/main_stage_2.py b/main_stage_2.py new file mode 100644 index 0000000..6295416 --- /dev/null +++ b/main_stage_2.py @@ -0,0 +1,186 @@ + +import argparse + + +import torch + +import torch.multiprocessing as mp +import torch.nn as nn +import torch.distributed as dist +# from transformers import BartTokenizer +from torch.utils.data import ConcatDataset + +from utils.init import initialize_from_env +# from datasets.pretraining import load_datasets, VideoTextRetDataset +# from datasets.utils import get_datasets_media +from models.setup import setup_model, setup_data, setup_data_test +# from tasks.ft_avsd import ft_avsd, generate +from tasks.stage_2 import train as train_stage_2 + + +parser = argparse.ArgumentParser(description='Main script for v2dial') +parser.add_argument( + '--model', + type=str, + default='v2dial/stage_2', + help='model name to train or test') + +parser.add_argument( + '--mode', + type=str, + default='train', + help='train, generate or debug' + ) + +parser.add_argument( + '--eval_dir', + type=str, + default='/scratch/abdessaied/projects/V2Dial_TU/logs/stage_4/v2dial-flant5_large_bert_experts_4_only_gen_AVSD' +) + +parser.add_argument( + '--wandb_mode', + type=str, + default='online', + choices=['online', 'offline', 'disabled', 'run', 'dryrun'] +) + +parser.add_argument( + '--wandb_project', + type=str, + default='V2Dial' +) + +parser.add_argument( + '--tag', + type=str, + # default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD', + # default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed', + # default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16', + default='from_stage_1_only_gen_loss_frozen_llm', + # default='blub', + # default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs', + help="Tag to differentiate the models" +) + +parser.add_argument( + '--medium', + type=str, + default='avsd', + help="Medium of the test dataset" +) + +parser.add_argument( + '--start_idx_gen', + type=int, + default=0, + help="The start index for generation" +) + +parser.add_argument( + '--end_idx_gen', + type=int, + default=10, + help="The end index for generation" +) + +parser.add_argument( + '--gen_subset_num', + type=int, + default=1, + help="The index of the test split for generation" +) + +parser.add_argument('--ssh', action='store_true', + help='whether or not we are executing command via ssh. ' + 'If set to True, we will not log.info anything to screen and only redirect them to log file') + + +def main(gpu, config, args): + + config['gpu'] = gpu + if config['distributed']: + dist.init_process_group( + backend='nccl', + world_size=config['num_gpus'], + rank=gpu + ) + torch.cuda.set_device(gpu) + + device = torch.device(f'cuda:{gpu}') + if config.use_cpu: + device = torch.device('cpu') + config['device'] = device + # model = V2Dial(config) + + # config['num_training_steps'] = num_step_per_epoch * config['epochs'] + # config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs'] + if config['training']: + train_dataloaders, val_dataloaders = setup_data(config) + + ( + model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, config + ) = setup_model(config) + + if config['training']: + train_stage_2( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + scheduler, + scaler, + start_epoch, + config + ) + + # if config['stage'] == 'stage_3': + # ( + # model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, config + # ) = setup_model(config) + # if config['training']: + # ft_avsd( + # model, + # model_without_ddp, + # train_dataloaders, + # val_dataloaders, + # optimizer, + # global_step, + # scheduler, + # scaler, + # start_epoch, + # config + # ) + # elif config['generating']: + # test_dataloader = setup_data_test(config, args) + # generate(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num) + + if config['distributed']: + dist.destroy_process_group() + +if __name__ == '__main__': + args = parser.parse_args() + + # initialization + model, stage = args.model.split('/') + config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag) + config['wandb_enabled'] = args.wandb_mode == 'online' + config['training'] = args.mode == 'train' + config['generating'] = args.mode == 'generate' + config['debugging'] = args.mode == 'debug' + + config['wandb_mode'] = args.wandb_mode + config['medium'] = args.medium + config['start_idx_gen'] = args.start_idx_gen + config['end_idx_gen'] = args.end_idx_gen + + # config['wandb_project'] + # if config['accelerator'] == 'ddp': + if config['num_gpus'] > 1: + config['distributed'] = True + mp.spawn(main, nprocs=config['num_gpus'], args=(config, args)) + else: + config['distributed'] = False + main(0, config, args) diff --git a/main_stage_3.py b/main_stage_3.py new file mode 100644 index 0000000..a63bcee --- /dev/null +++ b/main_stage_3.py @@ -0,0 +1,185 @@ + +import argparse + + +import torch + +import torch.multiprocessing as mp +import torch.nn as nn +import torch.distributed as dist +# from transformers import BartTokenizer +from torch.utils.data import ConcatDataset + +from utils.init import initialize_from_env +# from datasets.pretraining import load_datasets, VideoTextRetDataset +# from datasets.utils import get_datasets_media +from models.setup import setup_model, setup_data, setup_data_test +# from tasks.ft_avsd import ft_avsd, generate +from tasks.stage_3 import ft_avsd, generate, generate_nextqa, generate_visdial + +parser = argparse.ArgumentParser(description='Main script for v2dial') +parser.add_argument( + '--model', + type=str, + default='v2dial/stage_3', + help='model name to train or test') + +parser.add_argument( + '--mode', + type=str, + default='generate', + help='train, generate or debug' + ) + +parser.add_argument( + '--eval_dir', + type=str, + default='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetune_without_stc_stm_only_visdial' +) + +parser.add_argument( + '--wandb_mode', + type=str, + default='online', + choices=['online', 'offline', 'disabled', 'run', 'dryrun'] +) + +parser.add_argument( + '--wandb_project', + type=str, + default='V2Dial' +) + +parser.add_argument( + '--tag', + type=str, + default="finetuned_visdial_without_stm_stc", + # default='V2dial-bart_large-Experts_from_scratch-gen-modalityLayers_4-without_residuals-AVSD', + # default='Q_base_bart_base_from_modality_experts_c3m_webvid2mToVisdialToAVSD_num_hist3_with_fc_embed', + # default='like_mst_mixer_Q_base_bart_large_from_modality_experts_c3m_webvid2mToavsd_12_frames_without_temp_fp16', + # default='from_stage1_after_avsd_only_visdial_4_frames_10_rounds_ft', + # default='from_scratch_visdial', + # default='no_moes_div_st_from_scratch_only_avsd_4_frames_3_rounds_ft_fp16', + # default='flant5_large_bert_experts_4_only_gen_AVSD_24epochs', + help="Tag to differentiate the models" +) + +# parser.add_argument( +# '--medium', +# type=str, +# default='avsd', +# help="Medium of the test dataset" +# ) + +parser.add_argument( + '--start_idx_gen', + type=int, + default=0, + help="The start index for generation" +) + +parser.add_argument( + '--end_idx_gen', + type=int, + default=10, + help="The end index for generation" +) + +parser.add_argument( + '--gen_subset_num', + type=int, + default=1, + help="The index of the test split for generation" +) + +parser.add_argument('--ssh', action='store_true', + help='whether or not we are executing command via ssh. ' + 'If set to True, we will not log.info anything to screen and only redirect them to log file') + + +def main(gpu, config, args): + + config['gpu'] = gpu + if config['distributed']: + dist.init_process_group( + backend='nccl', + world_size=config['num_gpus'], + rank=gpu + ) + torch.cuda.set_device(gpu) + + device = torch.device(f'cuda:{gpu}') + if config.use_cpu: + device = torch.device('cpu') + config['device'] = device + # model = V2Dial(config) + + # config['num_training_steps'] = num_step_per_epoch * config['epochs'] + # config['num_warmup_steps'] = num_step_per_epoch * config['warmup_epochs'] + if config['training']: + train_dataloaders, val_dataloaders = setup_data(config) + + ( + model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, visdial_step, avsd_step, nextqa_step, config + ) = setup_model(config) + + if config['training']: + ft_avsd( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + visdial_step, + avsd_step, + nextqa_step, + scheduler, + scaler, + start_epoch, + config + ) + elif config['generating']: + test_dataloader = setup_data_test(config) + if config.media_test == 'avsd': + generate(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num) + if config.media_test == 'visdial': + generate_visdial(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num) + elif config.media_test == 'nextqa': + generate_nextqa(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num) + + if config['distributed']: + dist.destroy_process_group() + + +if __name__ == '__main__': + args = parser.parse_args() + + # initialization + model, stage = args.model.split('/') + config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag) + config['wandb_enabled'] = args.wandb_mode == 'online' + config['training'] = args.mode == 'train' + config['generating'] = args.mode == 'generate' + config['debugging'] = args.mode == 'debug' + + config['wandb_mode'] = args.wandb_mode + # config['medium'] = args.medium + config['start_idx_gen'] = args.start_idx_gen + config['end_idx_gen'] = args.end_idx_gen + config['expert_permutation'] = None + # config['expert_permutation'] = { + # 'spatial': 'history', + # 'temporal': 'temporal', + # 'caption': 'caption', + # 'history': 'spatial' + # } + + # config['wandb_project'] + # if config['accelerator'] == 'ddp': + if config['num_gpus'] > 1: + config['distributed'] = True + mp.spawn(main, nprocs=config['num_gpus'], args=(config, args)) + else: + config['distributed'] = False + main(0, config, args) diff --git a/merge_pred_avsd.py b/merge_pred_avsd.py new file mode 100644 index 0000000..c14d432 --- /dev/null +++ b/merge_pred_avsd.py @@ -0,0 +1,61 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser(description='Main script for MST-MIXER') +parser.add_argument( + '--dstc', + type=int, + default=7, + choices=[7, 8, 10], + help='DSTC challenge identifier') + +args = parser.parse_args() + +assert args.dstc in [7, 8, 10] +if args.dstc == 7: + output_dir = 'output/dstc7' + raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC7-AVSD.json' + +elif args.dstc == 8: + output_dir = 'output/dstc8' + raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC8-AVSD.json' +else: + output_dir = 'output/dstc10' + raw_data_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/AVSD/test_set4DSTC10-AVSD.json' + +with open(raw_data_path, 'r') as f: + raw_dialogs = json.load(f)['dialogs'] + +file_paths = os.listdir(output_dir) +file_paths = list(filter(lambda f: 'part' in f , file_paths)) +name = file_paths[0] +file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths)) + +dialogs = {} +for pth in file_paths: + with open(pth, 'r') as f: + data = json.load(f) + + for dialog in data['dialogs']: + vid_id = dialog['image_id'] + dialogs[vid_id] = dialog + # dialogs.extend(data['dialogs']) + os.remove(pth) + +# Now, re-establish the original order of the dialogs +res = [] +for dialog in raw_dialogs: + vid_id = dialog['image_id'] + res.append(dialogs[vid_id]) + +res = { + 'dialogs': res +} + +name = "".join(name.split('-')[:-1]) + '.json' +output_path = os.path.join(output_dir, name) +with open(output_path, 'w') as f: + json.dump(res, f, indent=4) + +print('[INFO] Files merged and saved in {}'.format(output_path)) \ No newline at end of file diff --git a/merge_pred_nextqa.py b/merge_pred_nextqa.py new file mode 100644 index 0000000..a4124fa --- /dev/null +++ b/merge_pred_nextqa.py @@ -0,0 +1,34 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser(description='Main script for MST-MIXER') + +args = parser.parse_args() + +output_dir = 'output/nextqa' + +file_paths = os.listdir(output_dir) +file_paths = list(filter(lambda f: 'part' in f , file_paths)) +name = file_paths[0] +file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths)) + +results = {} +for pth in file_paths: + with open(pth, 'r') as f: + data = json.load(f) + for video_id in data: + if video_id not in results: + results[video_id] = data[video_id] + else: + for qid in data[video_id]: + if qid not in results[video_id]: + results[video_id][qid] = data[video_id][qid] + os.remove(pth) + +name = "".join(name.split('-')[:-1]) + '.json' +output_path = os.path.join(output_dir, name) +with open(output_path, 'w') as f: + json.dump(results, f, indent=4) + +print('[INFO] Files merged and saved in {}'.format(output_path)) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/backbones/Qformer.py b/models/backbones/Qformer.py new file mode 100755 index 0000000..e71b123 --- /dev/null +++ b/models/backbones/Qformer.py @@ -0,0 +1,1216 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/models/backbones/__init__.py b/models/backbones/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/backbones/base_model.py b/models/backbones/base_model.py new file mode 100755 index 0000000..5da161f --- /dev/null +++ b/models/backbones/base_model.py @@ -0,0 +1,247 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import numpy as np +import torch +import torch.nn as nn +from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized +from models.common.utils import get_abs_path, is_url +from omegaconf import OmegaConf + + +class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + @property + def device(self): + return list(self.parameters())[0].device + + def load_checkpoint(self, url_or_filename): + """ + Load from a finetuned checkpoint. + + This should expect no mismatch in the model keys and the checkpoint keys. + """ + + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint.keys(): + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + @classmethod + def from_pretrained(cls, model_type): + """ + Build a pretrained model from default configuration file, specified by model_type. + + Args: + - model_type (str): model type, specifying architecture and checkpoints. + + Returns: + - model (nn.Module): pretrained or finetuned model, depending on the configuration. + """ + model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model + model = cls.from_config(model_cfg) + + return model + + @classmethod + def default_config_path(cls, model_type): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def load_checkpoint_from_config(self, cfg, **kwargs): + """ + Load checkpoint as specified in the config file. + + If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. + When loading the pretrained model, each task-specific architecture may define their + own load_from_pretrained() method. + """ + load_finetuned = cfg.get("load_finetuned", True) + if load_finetuned: + finetune_path = cfg.get("finetuned", None) + assert ( + finetune_path is not None + ), "Found load_finetuned is True, but finetune_path is None." + self.load_checkpoint(url_or_filename=finetune_path) + else: + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + assert "Found load_finetuned is False, but pretrain_path is None." + self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + +class BaseEncoder(nn.Module): + """ + Base class for primitive encoders, such as ViT, TimeSformer, etc. + """ + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +class SharedQueueMixin: + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr : ptr + batch_size] = image_feats.T + self.text_queue[:, ptr : ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = concat_all_gather(idxs) + self.idx_queue[:, ptr : ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr + + +class MomentumDistilationMixin: + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data = param_m.data * self.momentum + param.data * ( + 1.0 - self.momentum + ) + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/models/backbones/beit/__init__.py b/models/backbones/beit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/backbones/beit/builder.py b/models/backbones/beit/builder.py new file mode 100644 index 0000000..bc66e3c --- /dev/null +++ b/models/backbones/beit/builder.py @@ -0,0 +1,107 @@ +import logging +import torch +from models.utils import (interpolate_pos_relative_bias_beit, + load_temp_embed_with_mismatch) + +logger = logging.getLogger(__name__) + + +def interpolate_pos_embed_beit(state_dict, new_model): + """interpolate the positional embeddings. + The spatial pe is relative and temporal pe is absolute. + additional temporal pe is padded with 0. + + Args: + state_dict (dict): The state_dict. + new_model (nn.Module): The created model. + + Returns: dict. The state_dict with updated positional embeddings. + + """ + state_dict = interpolate_pos_relative_bias_beit( + state_dict_old=state_dict, + state_dict_new=new_model.state_dict(), + patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape, + ) + # absolute temporal pos bias + temporal_pe_key = "beit.embeddings.temporal_position_embeddings" + if temporal_pe_key in state_dict: + logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}") + state_dict[temporal_pe_key] = load_temp_embed_with_mismatch( + temp_embed_old=state_dict[temporal_pe_key], + temp_embed_new=new_model.state_dict()[temporal_pe_key], + ) + return state_dict + +def extract_beit_from_vindlu(vindlu_state_dict): + beit_state_dict = {} + beit_param_names = [k for k in vindlu_state_dict if k.startswith('vision_encoder.') and 'temp_model' not in k] + for param_name in beit_param_names: + new_name = param_name.replace('vision_encoder.', '') + beit_state_dict[new_name] = vindlu_state_dict[param_name] + + return beit_state_dict + +def build_beit(model_config, image_res, checkpoint=False): + """build beit with configuration. + + Args: + config (dict): The configs for beit. + image_res (int): The image resolution. + checkpoint (bool): Whether to enable gradient checkpointing. + + Returns: nn.Module + + """ + from .st_beit import BeitConfig as config_cls + from .st_beit import BeitModel as model_cls + + + vindlu_state_dict = torch.load(model_config['vindlu_path'])['model'] + state_dict = extract_beit_from_vindlu(vindlu_state_dict) + model_config = model_config['beit_config_json'] + + logger.info( + f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}." + ) + # BEiT uses average pooled tokens instead of [CLS] used by other models + aux_kwargs = {"add_pooling_layer": True} + # tmp_model = model_cls.from_pretrained(model_config['beit_pretrained'], **aux_kwargs) + + + # tmp_model = model_cls.from_pretrained(model_config['pretrained'], **aux_kwargs) + # state_dict = tmp_model.state_dict() + + # del tmp_model + + logger.info(f"Init new model with new image size {image_res}, and load weights.") + + # other_cfg = model_config.temporal_modeling + other_cfg = {} + + vit_config = config_cls.from_pretrained( + model_config['pretrained'], image_size=image_res, **other_cfg + ) + + # vit_config.update(model_config) + + model = model_cls(config=vit_config, **aux_kwargs) + + if checkpoint: + model.gradient_checkpointing_enable() + + # interpolate relative pos bias + state_dict = interpolate_pos_relative_bias_beit( + state_dict_old=state_dict, + state_dict_new=model.state_dict(), + patch_shape_new=model.embeddings.patch_embeddings.patch_shape, + ) + + # del prompt_bias_table + for k in list(state_dict.keys()): + if "prompt_bias_table" in k: + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + logger.info(msg) + return model diff --git a/models/backbones/beit/st_beit.py b/models/backbones/beit/st_beit.py new file mode 100644 index 0000000..ae38511 --- /dev/null +++ b/models/backbones/beit/st_beit.py @@ -0,0 +1,1752 @@ +# coding=utf-8 +# Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BEiT model.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import einops +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedLMOutput, + SemanticSegmenterOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import (find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +from models.utils import interpolate_temporal_pos_embed + +from ...modules.temporal_modelling import (X_CLIP, STAdapter, TemporalAttention, + TemporalS4, WindowTemporalAttention) + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "BeitConfig" +_FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224-pt22k" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/beit-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/beit-base-patch16-224", + # See all BEiT models at https://huggingface.co/models?filter=beit +] + + +class BeitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the BEiT + [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 8092): + Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during + pre-training. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`): + Indices of the feature maps to use for semantic segmentation. + pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + + Example: + + ```python + >>> from transformers import BeitModel, BeitConfig + + >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration + >>> configuration = BeitConfig() + + >>> # Initializing a model from the beit-base-patch16-224-pt22k style configuration + >>> model = BeitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "beit" + + def __init__( + self, + vocab_size=8192, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + is_encoder_decoder=False, + image_size=224, + num_frames=1, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + out_indices=[3, 5, 7, 11], + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + semantic_loss_ignore_index=255, + + temporal_model_block="none", + temporal_model_position="last", + temporal_model_init_value=0.0, + temporal_model_config={}, + use_temporal_position_embedding=False, + add_k_prompts=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling + # decode head attributes (semantic segmentation) + self.out_indices = out_indices + self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + self.temporal_model_block = temporal_model_block + self.temporal_model_config = temporal_model_config + self.temporal_model_position = temporal_model_position + self.temporal_model_init_value = temporal_model_init_value + self.use_temporal_position_embedding = use_temporal_position_embedding + self.add_k_prompts = add_k_prompts + self.num_frames = num_frames + + +@dataclass +class BeitModelOutputWithPooling(BaseModelOutputWithPooling): + """ + Class for outputs of [`BeitModel`]. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * ( + input.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class BeitDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class BeitEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = BeitPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter( + torch.zeros(1, num_patches + 1, config.hidden_size) + ) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + if config.use_temporal_position_embedding: + self.temporal_position_embeddings = nn.parameter.Parameter( + torch.zeros(1, config.num_frames, 1, config.hidden_size) + ) + else: + self.temporal_position_embeddings = None + + if config.add_k_prompts > 0: + self.prompt_tokens = nn.parameter.Parameter( + torch.zeros(1, config.add_k_prompts, config.hidden_size) + ) + else: + self.prompt_tokens = None + + def forward( + self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input image patches. Shape: [B, T, C, H, W]. + + + """ + t = pixel_values.shape[1] + pixel_values = einops.rearrange(pixel_values, "b t c h w -> (b t) c h w") + + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() # [(b t) l c] + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + if self.prompt_tokens is not None: + prompt_tokens = self.prompt_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings, prompt_tokens), dim=1) + else: + embeddings = torch.cat((cls_tokens, embeddings), dim=1) # [B*T, L, C] + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = einops.rearrange(embeddings, "(b t) l c -> b t l c", t=t) + if self.temporal_position_embeddings is not None: + if t <= self.temporal_position_embeddings.shape[1]: + embeddings = embeddings + self.temporal_position_embeddings[:, :t] + else: + tpe = interpolate_temporal_pos_embed(self.temporal_position_embeddings, t) + embeddings = embeddings + tpe + + embeddings = self.dropout(embeddings) + + return embeddings + + +class BeitPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + + return embeddings + + +class BeitSelfAttention(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + if window_size: + self.relative_position_bias = BeitRelativePositionBias( + config, window_size=window_size + ) + else: + self.relative_position_bias = None + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Add relative position bias if present. + if self.relative_position_bias is not None: + attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class BeitSelfOutput(nn.Module): + """ + The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class BeitAttention(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.attention = BeitSelfAttention(config, window_size=window_size) + self.output = BeitSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BeitIntermediate(nn.Module): + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class BeitOutput(nn.Module): + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class TemporalAttentionBeit(nn.Module): + + """temporal attention using BeitAttention""" + + def __init__(self, config: BeitConfig): + """TODO: to be defined.""" + super().__init__() + + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = BeitAttention(config, window_size=None) + self.scale = nn.Parameter( + config.temporal_model_init_value * torch.ones((config.hidden_size)), + requires_grad=True, + ) + self.drop_path = BeitDropPath(config.drop_path_rate) + + def forward(self, hidden_states: torch.Tensor): + """forward function + + Args: + hidden_states (torch.Tensor): The input. Shape: [b,t,l,c] + + Returns: TODO + + """ + b = hidden_states.shape[0] + output = einops.rearrange(hidden_states, "b t l c -> (b l) t c") + output = self.layernorm_before(output) + output = self.attention(output) + output = einops.rearrange(output[0], "(b l) t c -> b t l c", b=b) + return hidden_states + self.drop_path(output[0]) * self.scale + + +class BeitLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__( + self, + config: BeitConfig, + window_size: Optional[tuple] = None, + drop_path_rate: float = 0.0, + ) -> None: + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BeitAttention(config, window_size=window_size) + self.intermediate = BeitIntermediate(config) + self.output = BeitOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.drop_path = ( + BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.temporal_model_position = config.temporal_model_position + + init_values = config.layer_scale_init_value + if init_values > 0: + self.lambda_1 = nn.Parameter( + init_values * torch.ones((config.hidden_size)), requires_grad=True + ) + self.lambda_2 = nn.Parameter( + init_values * torch.ones((config.hidden_size)), requires_grad=True + ) + else: + self.lambda_1, self.lambda_2 = None, None + + if config.temporal_model_block == "st_adapter": + self.temp_model = STAdapter(**config.temporal_model_config) + elif config.temporal_model_block == "timesformer": + self.temp_model = TemporalAttention(**config.temporal_model_config) + elif config.temporal_model_block == "s4": + self.temp_model = TemporalS4(**config.temporal_model_config) + elif config.temporal_model_block == "ta_beit": + self.temp_model = TemporalAttentionBeit(config) + elif config.temporal_model_block == "window_attention": + self.temp_model = WindowTemporalAttention(**config.temporal_model_config) + elif config.temporal_model_block == "xclip": + self.temp_model = X_CLIP(**config.temporal_model_config) + elif config.temporal_model_block == "none": + self.temp_model = None + else: + raise ValueError(f"not accepted temporal model: {config.temporal_model_block}") + + self.temporal_model_block = config.temporal_model_block + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + + b, t, l, c = hidden_states.shape + + if self.temporal_model_block == "xclip": + assert ( + self.temporal_model_position == "first" and self.config.add_k_prompts == 1 + ), "xclip must be put before the attention and add_k_prompts must be 1." + + if self.temp_model is not None and self.temporal_model_position == "first": + hidden_states = self.temp_model(hidden_states) + + hidden_states = einops.rearrange(hidden_states, "b t l c -> (b t) l c") + + self_attention_outputs = self.attention( + self.layernorm_before( + hidden_states + ), # in BEiT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + ) + attention_output = self_attention_outputs[0] + + # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + layer_output = einops.rearrange(layer_output, "(b t) l c -> b t l c", b=b) + + # apply temporal modeling block + if self.temp_model is not None and self.temporal_model_position == "last": + layer_output = self.temp_model(layer_output) + + outputs = (layer_output,) + outputs + + return outputs + + +class BeitRelativePositionBias(nn.Module): + def __init__(self, config: BeitConfig, window_size: tuple) -> None: + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, config.num_attention_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # add bias for prompts + k = config.add_k_prompts + self.k = k + if k > 0: + self.prompt_bias_table = nn.parameter.Parameter( + torch.zeros((2 + k) * k, config.num_attention_heads) + ) # k prompt-to-token, k token-to-prompt, k*k prompt-to-promt + else: + self.prompt_bias_table = None + + def forward(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + + k = self.k + if k > 0: + l = self.window_size[0] * self.window_size[1] + 1 + bias = torch.zeros(l + k, l + k, relative_position_bias.shape[-1]).to( + relative_position_bias.device + ) + bias[:l, :l] = relative_position_bias + bias[l:, :l] = self.prompt_bias_table[:k].view(k, 1, -1) # prompt to token + bias[:l, l:] = self.prompt_bias_table[k : 2 * k].view(1, k, -1) # token to prompt + bias[l:, l:] = self.prompt_bias_table[2 * k, :].view(k, k, -1) # prompt to prompt + + # bias[k:, k:] = relative_position_bias + # bias[:k, k:] = self.prompt_bias_table[:k].view(k, 1, -1) + # bias[k:, :k] = self.prompt_bias_table[k : 2 * k].view(1, k, -1) + # bias[:k, :k] = self.prompt_bias_table[2 * k :].view(k, k, -1) + else: + bias = relative_position_bias + + return bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BeitEncoder(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + if config.use_shared_relative_position_bias: + self.relative_position_bias = BeitRelativePositionBias( + config, window_size=window_size + ) + else: + self.relative_position_bias = None + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers) + ] + self.layer = nn.ModuleList( + [ + BeitLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + ) + for i in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + # all_hidden_states = all_hidden_states + ( + # einops.rearrange(hidden_states, "b t l c -> (b t) l c"), + # ) + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + use_reentrant=False, + ) + else: + relative_position_bias = ( + self.relative_position_bias() + if self.relative_position_bias is not None + else None + ) + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions, relative_position_bias + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + # hidden_states = einops.rearrange(hidden_states, "b t l c -> (b t) l c") + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class BeitPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BeitConfig + base_model_prefix = "beit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BeitEncoder): + module.gradient_checkpointing = value + + +BEIT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`BeitConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BEIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`BeitFeatureExtractor`]. See + [`BeitFeatureExtractor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", + BEIT_START_DOCSTRING, +) +class BeitModel(BeitPreTrainedModel): + def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None: + super().__init__(config) + self.config = config + + self.embeddings = BeitEmbeddings(config) + self.encoder = BeitEncoder( + config, window_size=self.embeddings.patch_embeddings.patch_shape + ) + + self.layernorm = ( + nn.Identity() + if config.use_mean_pooling + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.pooler = BeitPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BeitModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BeitModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # pixel_values: [bsz, nframes, c, h, w] + assert pixel_values.ndim == 5, logger.error( + f"input shape to st_beit: {pixel_values.shape}" + ) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos + ) # [bs, nframes, L, c] + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + # logger.info(f"sequence_output: {sequence_output.shape}. pooled_output: {pooled_output.shape}") + + if not return_dict: + head_outputs = ( + (sequence_output, pooled_output) + if pooled_output is not None + else (sequence_output,) + ) + return head_outputs + encoder_outputs[1:] + + return BeitModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class BeitPooler(nn.Module): + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.num_prompts = config.add_k_prompts + self.layernorm = ( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if config.use_mean_pooling + else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Shape: [B,T,L,C] + """ + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + # patch_tokens = hidden_states[:, 1 + self.num_prompts :, :] + if self.num_prompts > 0: + patch_tokens = hidden_states[:, :, 1 : -self.num_prompts, :] + else: + patch_tokens = hidden_states[:, :, 1:, :] + pooled_output = self.layernorm(patch_tokens.mean(2)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, :, 0] + + return pooled_output + + +@add_start_docstrings( + """Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting + visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT + predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you + will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""", + BEIT_START_DOCSTRING, +) +class BeitForMaskedImageModeling(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # Classifier head + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, logits = outputs.loss, outputs.logits + >>> list(logits.shape) + [1, 196, 8192] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + sequence_output = self.layernorm(sequence_output) + prediction_scores = self.lm_head(sequence_output[:, 1:]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final + hidden states of the patch tokens) e.g. for ImageNet. + """, + BEIT_START_DOCSTRING, +) +class BeitForImageClassification(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=True) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BeitConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + + return output + + +class BeitPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + BeitConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class BeitPyramidPoolingModule(nn.Module): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + pool_scales: Tuple[int, ...], + in_channels: int, + channels: int, + align_corners: bool, + ) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = BeitPyramidPoolingBlock( + pool_scale=pool_scale, in_channels=in_channels, channels=channels + ) + self.blocks.append(block) + self.add_module(str(i), block) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + ppm_outs = [] + for ppm in self.blocks: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class BeitUperHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://arxiv.org/abs/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] + self.channels = config.hidden_size + self.align_corners = False + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + # PSP Module + self.psp_modules = BeitPyramidPoolingModule( + self.pool_scales, + self.in_channels[-1], + self.channels, + align_corners=self.align_corners, + ) + self.bottleneck = BeitConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1) + fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = BeitConvModule( + len(self.in_channels) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # build laterals + laterals = [ + lateral_conv(encoder_hidden_states[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate( + laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners + ) + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = nn.functional.interpolate( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode="bilinear", + align_corners=self.align_corners, + ) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + +class BeitFCNHead(nn.Module): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented of + [FCNNet](https://arxiv.org/abs/1411.4038>). + + Args: + config (BeitConfig): Configuration. + in_channels + kernel_size (int): The kernel size for convs in the head. Default: 3. + dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + config: BeitConfig, + in_index: int = 2, + kernel_size: int = 3, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + BeitConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + ) + ) + for i in range(self.num_convs - 1): + convs.append( + BeitConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = BeitConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@add_start_docstrings( + """ + Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + BEIT_START_DOCSTRING, +) +class BeitForSemanticSegmentation(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # FPNs + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d( + config.hidden_size, config.hidden_size, kernel_size=2, stride=2 + ), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d( + config.hidden_size, config.hidden_size, kernel_size=2, stride=2 + ), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d( + config.hidden_size, config.hidden_size, kernel_size=2, stride=2 + ), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # Semantic segmentation head(s) + self.decode_head = BeitUperHead(config) + self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + main_loss = loss_fct(upsampled_logits, labels) + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, BeitForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [ + feature + for idx, feature in enumerate(encoder_hidden_states) + if idx + 1 in self.config.out_indices + ] + batch_size = pixel_values.shape[0] + patch_resolution = self.config.image_size // self.config.patch_size + features = [ + x[:, 1:, :] + .permute(0, 2, 1) + .reshape(batch_size, -1, patch_resolution, patch_resolution) + for x in features + ] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + logits = self.decode_head(features) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + else: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/models/backbones/bert/__init__.py b/models/backbones/bert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/backbones/bert/builder.py b/models/backbones/bert/builder.py new file mode 100644 index 0000000..18271b0 --- /dev/null +++ b/models/backbones/bert/builder.py @@ -0,0 +1,71 @@ +from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel + + +def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'): + """build text encoder. + + Args: + model_config (dict): model config. + pretrain (bool): Whether to do pretrain or finetuning. + checkpoint (bool): whether to do gradient_checkpointing. + + Returns: TODO + + """ + bert_size = model_config['expert_size'] + bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}']) + # bert_config.encoder_width = model_config.vision_encoder.d_model + bert_config.gradient_checkpointing = checkpoint + bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)] + if expert_type=='modality': + if modality_type == 'vis': + bert_config.cross_attention_freq = 2 + else: + bert_config.cross_attention_freq = -1 + else: + bert_config.cross_attention_freq = 1 + + if pretrain: + text_encoder, loading_info = BertForMaskedLM.from_pretrained( + f'bert-{bert_size}-uncased', + config=bert_config, + output_loading_info=True, + ) + else: + text_encoder, loading_info = BertModel.from_pretrained( + f'bert-{bert_size}-uncased', + config=bert_config, + add_pooling_layer=True, + output_loading_info=True, + ) + + return text_encoder + + +def build_bert_decoder(model_config, checkpoint): + """build text decoder the same as the multimodal encoder. + + Args: + model_config (dict): model config. + pretrain (bool): Whether to do pretrain or finetuning. + checkpoint (bool): whether to do gradient_checkpointing. + + Returns: TODO + + """ + bert_config = BertConfig.from_json_file(model_config.text_encoder.config) + bert_config.encoder_width = model_config.vision_encoder.d_model + bert_config.gradient_checkpointing = checkpoint + + bert_config.fusion_layer = 0 + bert_config.num_hidden_layers = ( + bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer + ) + + text_decoder, loading_info = BertLMHeadModel.from_pretrained( + model_config.text_encoder.pretrained, + config=bert_config, + output_loading_info=True, + ) + + return text_decoder diff --git a/models/backbones/bert/tokenization_bert.py b/models/backbones/bert/tokenization_bert.py new file mode 100644 index 0000000..66e8d8e --- /dev/null +++ b/models/backbones/bert/tokenization_bert.py @@ -0,0 +1,546 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Bert.""" + + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", + "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", + "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", + "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", + "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", + "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", + "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", + "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", + "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", + "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", + "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", + "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", + "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", + "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", + "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", + "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", + "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", + "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "bert-base-uncased": 512, + "bert-large-uncased": 512, + "bert-base-cased": 512, + "bert-large-cased": 512, + "bert-base-multilingual-uncased": 512, + "bert-base-multilingual-cased": 512, + "bert-base-chinese": 512, + "bert-base-german-cased": 512, + "bert-large-uncased-whole-word-masking": 512, + "bert-large-cased-whole-word-masking": 512, + "bert-large-uncased-whole-word-masking-finetuned-squad": 512, + "bert-large-cased-whole-word-masking-finetuned-squad": 512, + "bert-base-cased-finetuned-mrpc": 512, + "bert-base-german-dbmdz-cased": 512, + "bert-base-german-dbmdz-uncased": 512, + "TurkuNLP/bert-base-finnish-cased-v1": 512, + "TurkuNLP/bert-base-finnish-uncased-v1": 512, + "wietsedv/bert-base-dutch-cased": 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + "bert-base-uncased": {"do_lower_case": True}, + "bert-large-uncased": {"do_lower_case": True}, + "bert-base-cased": {"do_lower_case": False}, + "bert-large-cased": {"do_lower_case": False}, + "bert-base-multilingual-uncased": {"do_lower_case": True}, + "bert-base-multilingual-cased": {"do_lower_case": False}, + "bert-base-chinese": {"do_lower_case": False}, + "bert-base-german-cased": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, + "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, + "bert-base-german-dbmdz-cased": {"do_lower_case": False}, + "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, + "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, + "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, + "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, +} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(PreTrainedTokenizer): + r""" + Construct a BERT tokenizer. Based on WordPiece. + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + Args: + vocab_file (:obj:`str`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to do basic tokenization before WordPiece. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs + ): + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + vocab_file) + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=self.unk_token) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + - single sequence: ``[CLS] X `` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + "Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format( + vocab_file) + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer(object): + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + Args: + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see + WordPieceTokenizer. + Args: + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + :func:`PreTrainedTokenizer.tokenize`) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union( + set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens diff --git a/models/backbones/bert/xbert.py b/models/backbones/bert/xbert.py new file mode 100644 index 0000000..494eff2 --- /dev/null +++ b/models/backbones/bert/xbert.py @@ -0,0 +1,2160 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from torch import Tensor, device, dtype, nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +# from transformers.models.bert.configuration_bert import BertConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.file_utils import (ModelOutput, add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, MaskedLMOutput, + MultipleChoiceModelOutput, NextSentencePredictorOutput, + QuestionAnsweringModelOutput, SequenceClassifierOutput, + TokenClassifierOutput) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import logging + +transformers.logging.set_verbosity_error() + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to + instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BERT + [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import BertModel, BertConfig + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + cross_module="ca", + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.cross_module = cross_module + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device + ) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + # added `attention_scores` to return tuple + outputs = ( + (context_layer, attention_probs, attention_scores) + if output_attentions + else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + + self.self = BertSelfAttention(config, is_cross_attention) + + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] + return outputs # (context_layer, attention_probs, attention_scores, past_key_value,) + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = config.cross_attention_freq > 0 and layer_num % config.cross_attention_freq == 0 + + if self.has_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[ + (self.layer_num - self.config.fusion_layer) + % len(encoder_hidden_states) + ], + encoder_attention_mask[ + (self.layer_num - self.config.fusion_layer) + % len(encoder_hidden_states) + ], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + logger.info(f"build bert with cross_module: {config.cross_module}") + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + # mode="multi_modal", + normalize_attention=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + # if ( + # mode == "text" or mode == "temporal" + # ): # temporal is added and used for temporal att module. + # start_layer = 0 + # output_layer = self.config.fusion_layer + + # elif mode == "fusion": + # start_layer = self.config.fusion_layer + # output_layer = self.config.num_hidden_layers + + # elif mode == "multi_modal": + # start_layer = 0 + # output_layer = self.config.num_hidden_layers + + for i in range(len(self.layer)): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_reentrant=False, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + # whether to output normalized attention, + # note for unnormalized attention, there is a mask added + offset = int(normalize_attention) + # all_self_attentions = all_self_attentions + (layer_outputs[1], ) + all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],) + if hasattr(layer_module, "crossattention"): + # all_cross_attentions = all_cross_attentions + (layer_outputs[3], ) + all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.BertForPreTraining`. + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + # mode="multi_modal", + normalize_attention=True, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds or encoder_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # mode=mode, + normalize_attention=normalize_attention, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + ) + @replace_return_docstrings( + output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, + BERT_START_DOCSTRING, +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + ) + @replace_return_docstrings( + output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction="mean", + mode="multi_modal", + normalize_attention=True, + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + normalize_attention=normalize_attention, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, dim=-1 + ) + loss_distill = (loss_distill * (labels != -100)).sum(1) + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, **model_kwargs + ): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), + ) + return reordered_past + + +@dataclass +class MaskedLMOutputWithDistill(MaskedLMOutput): + loss_aux: Optional[torch.FloatTensor] = None + loss_distill: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING +) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def tie_aux_decoder_weights(self, module, aux_modules): + """Tie decoder weights of all `aux_modules` to `module`, (not bias)""" + for m in aux_modules: + m.predictions.decoder.weight = module.predictions.decoder.weight + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + # mode="multi_modal", + normalize_attention=True, + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + # mode=mode, + normalize_attention=normalize_attention, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + masked_lm_loss_aux = 0.0 + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1 + ) + loss_distill = loss_distill[labels != -100].mean() + masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + # changed from MaskedLMOutput to MaskedLMOutputWithDistill + return MaskedLMOutputWithDistill( + loss=masked_lm_loss, + loss_aux=masked_lm_loss_aux, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + last_hidden_state=outputs.last_hidden_state + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert ( + self.config.pad_token_id is not None + ), "The PAD token should be defined for generation" + attention_mask = torch.cat( + [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1 + ) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + ) + @replace_return_docstrings( + output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ( + ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + ) + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = ( + attention_mask.view(-1, attention_mask.size(-1)) + if attention_mask is not None + else None + ) + token_type_ids = ( + token_type_ids.view(-1, token_type_ids.size(-1)) + if token_type_ids is not None + else None + ) + position_ids = ( + position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + ) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, + labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels), + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/models/backbones/blip2.py b/models/backbones/blip2.py new file mode 100755 index 0000000..33707b6 --- /dev/null +++ b/models/backbones/blip2.py @@ -0,0 +1,268 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import contextlib +import logging +import os +import time +import datetime + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +# import .backbones.common.dist_utils as dist_utils +# from minigpt4.common.dist_utils import download_cached_file +# from minigpt4.common.utils import is_url +# from minigpt4.common.logger import MetricLogger + +from models.backbones.base_model import BaseModel +from models.backbones.Qformer import BertConfig, BertLMHeadModel +from models.backbones.eva_vit import create_eva_vit_g +from transformers import BertTokenizer + + +class Blip2Base(BaseModel): + @classmethod + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision + ): + assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + def get_optimizer_params(self, weight_decay, lr_scale=1): + + vit_num_layers = self.visual_encoder.get_num_layer() + lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2)) + + parameter_group_names = {} + parameter_group_vars = {} + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias"): + group_name = "no_decay" + this_weight_decay = 0. + else: + group_name = "decay" + this_weight_decay = weight_decay + if 'visual_encoder' in name: + layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.','')) + group_name = "vit_layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if layer_id is not None: + scale = lr_scales[layer_id] + else: + scale = 1 + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # import json + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + optim_params = list(parameter_group_vars.values()) + return optim_params + + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_feat = model.forward_text(text_input) + text_embed = F.normalize(model.text_proj(text_feat)) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + + vit_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(model.device) + image_feat, vit_feat = model.forward_image(image) + image_embed = model.vision_proj(image_feat) + image_embed = F.normalize(image_embed, dim=-1) + + vit_feats.append(vit_feat.cpu()) + image_embeds.append(image_embed) + + vit_feats = torch.cat(vit_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = [] + for image_embed in image_embeds: + sim_q2t = image_embed @ text_embeds.t() + sim_i2t, _ = sim_q2t.max(0) + sims_matrix.append(sim_i2t) + sims_matrix = torch.stack(sims_matrix, dim=0) + + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[topk_idx], + text_atts=text_atts[topk_idx], + ).float() + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[topk_idx.cpu()].to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[start + i].repeat(k_test, 1), + text_atts=text_atts[start + i].repeat(k_test, 1), + ).float() + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/models/backbones/blip2_outputs.py b/models/backbones/blip2_outputs.py new file mode 100755 index 0000000..e8722b1 --- /dev/null +++ b/models/backbones/blip2_outputs.py @@ -0,0 +1,110 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) + + +@dataclass +class BlipSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipIntermediateOutput(ModelOutput): + """ + Data class for intermediate outputs of BLIP models. + + image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). + text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). + + image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). + text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). + + encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. + encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. + + decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. + decoder_labels (torch.LongTensor): labels for the captioning loss. + + itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). + itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) + + """ + + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class BlipOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[BlipSimilarity] = None + + intermediate_output: BlipIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_lm: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipOutputFeatures(ModelOutput): + """ + Data class of features from BlipFeatureExtractor. + + Args: + image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional + image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional + text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional + text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/models/backbones/clip_vision_encoder.py b/models/backbones/clip_vision_encoder.py new file mode 100644 index 0000000..8518551 --- /dev/null +++ b/models/backbones/clip_vision_encoder.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn + +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + + +class CLIPVisionEncoder(nn.Module): + def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_encoder_name = encoder_name + # self.select_layer = args.mm_vision_select_layer + # self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + self.select_layer = -1 + self.select_feature = "patch" + if not delay_load: + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name) + + def load_model(self): + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name) + self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name) + self.vision_encoder.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs.hidden_states[self.select_layer] + if self.select_feature == 'patch': + image_features = image_features[:, :] + elif self.select_feature == 'cls_patch': + image_features = image_features + else: + raise ValueError(f'Unexpected select feature: {self.select_feature}') + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + # print("image feature shape", image_features.shape) + # print(type(image_forward_outs)) + # print(type(image_forward_outs.shape)) + # image_features = image_forward_outs.to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_encoder.dtype + + @property + def device(self): + return self.vision_encoder.device + + @property + def config(self): + if self.is_loaded: + return self.vision_encoder.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 \ No newline at end of file diff --git a/models/backbones/encoder_decoder/builder.py b/models/backbones/encoder_decoder/builder.py new file mode 100644 index 0000000..1aff26a --- /dev/null +++ b/models/backbones/encoder_decoder/builder.py @@ -0,0 +1,141 @@ + +import glog as logger +import re +import json + +from peft import LoraConfig, get_peft_model + +from .xflan_t5 import T5Config, T5ForConditionalGeneration +from .xbart import BartConfig, BartForConditionalGeneration, BartEncoder, BartForCausalLM + + +def build_encoder_decoder(model_config): + """build (encoder-) decoder model for answer generation. + + Args: + model_config (dict): model config. + + Returns: TODO + + """ + logger.info('[INFO] Loading Encoder Decoder [Type = {}]'.format(model_config['enc_dec_name'])) + + if model_config['enc_dec_family'] == 'flan_t5': + config_cls = T5Config + model_cls = T5ForConditionalGeneration + elif model_config['enc_dec_family'] == 'bart': + config_cls = BartConfig + if model_config['use_decoder_only']: + model_cls = BartForCausalLM + else: + model_cls = BartForConditionalGeneration + else: + raise ValueError('{} is not supported'.format(model_config['enc_dec_family'])) + enc_dec_config = config_cls.from_pretrained(model_config['enc_dec_name']) + model_config['enc_dec_dim'] = enc_dec_config.d_model + # enc_dec_config.encoder_layers = enc_dec_config.encoder_layers - model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])] + enc_dec = model_cls.from_pretrained( + model_config['enc_dec_name'], + config=enc_dec_config + ) + + # first_k = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])] + # enc_dec.model.encoder.remove_first_k_layers(first_k) + # get the last encoder layers + # enc_dec. + + + if model_config['use_lora_enc_dec']: + # load the lora config + with open(model_config['lora_config'], 'r') as f: + lora_config = json.load(f) + + # get the linear layer to perform LoRA on + model_modules = str(enc_dec.modules) + pattern = r'\((\w+)\): Linear' + linear_layer_names = re.findall(pattern, model_modules) + + names = [] + # Print the names of the Linear layers + for name in linear_layer_names: + names.append(name) + target_modules = list(set(names)) + + lora_config['target_modules'] = target_modules + + lora_config = LoraConfig(**lora_config) + + enc_dec = get_peft_model(enc_dec, lora_config) + + return enc_dec + + +def build_encoder(model_config, expert_type, modality=None): + """build (encoder-) decoder model for answer generation. + + Args: + model_config (dict): model config. + + Returns: TODO + + """ + log_txt = '[INFO] Loading {} Expert'.format(expert_type) + if modality is not None: + log_txt += ' [Modality = {}]'.format(modality) + log_txt += ' [Type = {}]'.format(model_config['enc_dec_name']) + + logger.info(log_txt) + + if model_config['enc_dec_family'] == 'flan_t5': + config_cls = T5Config + model_cls = T5ForConditionalGeneration + elif model_config['enc_dec_family'] == 'bart': + config_cls = BartConfig + model_cls = BartEncoder + else: + raise ValueError('{} is not supported'.format(model_config['enc_dec_family'])) + + config = config_cls.from_pretrained(model_config['enc_dec_name']) + config.modality_expert_layers = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])] + config.grounding_expert_layers = model_config['num_layers_grounding_expert_{}'.format(model_config['enc_dec_family'])] + + model_config['enc_dec_dim'] = config.d_model + + expert = model_cls.from_pretrained( + model_config['enc_dec_name'], + config=config, + expert_type=expert_type, + modality=modality + ) + + if model_config['use_lora_expert']: + # load the lora config + with open(model_config['lora_config'], 'r') as f: + lora_config = json.load(f) + + # get the linear layer to perform LoRA on + model_modules = str(expert.modules) + pattern = r'\((\w+)\): Linear' + linear_layer_names = re.findall(pattern, model_modules) + + names = [] + # Print the names of the Linear layers + for name in linear_layer_names: + names.append(name) + target_modules = list(set(names)) + + lora_config['target_modules'] = target_modules + + lora_config = LoraConfig(**lora_config) + + expert = get_peft_model(expert, lora_config) + + # expert = model_cls( + # config=config, + # expert_type=expert_type, + # modality=modality + # ) + + return expert + + diff --git a/models/backbones/encoder_decoder/builder_orig.py b/models/backbones/encoder_decoder/builder_orig.py new file mode 100644 index 0000000..119ef0e --- /dev/null +++ b/models/backbones/encoder_decoder/builder_orig.py @@ -0,0 +1,65 @@ +from .xflan_t5 import T5Config, T5ForConditionalGeneration +from .xbart_original import BartConfig, BartForConditionalGeneration, BartEncoder + +import glog as logger + + +def build_encoder_decoder(model_config): + """build (encoder-) decoder model for answer generation. + + Args: + model_config (dict): model config. + + Returns: TODO + + """ + logger.info('[INFO] Loading Encoder Decoder: {}'.format(model_config['enc_dec_name'])) + + if model_config['enc_dec_family'] == 'flan_t5': + config_cls = T5Config + model_cls = T5ForConditionalGeneration + elif model_config['enc_dec_family'] == 'bart': + config_cls = BartConfig + model_cls = BartForConditionalGeneration + else: + raise ValueError('{} is not supported'.format(model_config['enc_dec_family'])) + config = config_cls.from_pretrained(model_config['enc_dec_name']) + model_config['enc_dec_dim'] = config.d_model + enc_dec = model_cls.from_pretrained( + model_config['enc_dec_name'], + config=config + ) + + return enc_dec + + +def build_encoder(model_config): + """build (encoder-) decoder model for answer generation. + + Args: + model_config (dict): model config. + + Returns: TODO + + """ + logger.info('[INFO] Loading Expert as Encoder of {}'.format(model_config['enc_dec_name'])) + + if model_config['enc_dec_family'] == 'flan_t5': + config_cls = T5Config + model_cls = T5ForConditionalGeneration + elif model_config['enc_dec_family'] == 'bart': + config_cls = BartConfig + model_cls = BartEncoder + else: + raise ValueError('{} is not supported'.format(model_config['enc_dec_family'])) + + config = config_cls.from_pretrained(model_config['enc_dec_name']) + model_config['enc_dec_dim'] = config.d_model + config.encoder_layers = model_config['num_layers_modality_expert'] + + expert = model_cls.from_pretrained( + model_config['enc_dec_name'], + config=config + ) + + return expert diff --git a/models/backbones/encoder_decoder/outputs.py b/models/backbones/encoder_decoder/outputs.py new file mode 100644 index 0000000..7330ec0 --- /dev/null +++ b/models/backbones/encoder_decoder/outputs.py @@ -0,0 +1,19 @@ +from typing import Optional, Tuple +import torch +from transformers.modeling_outputs import ModelOutput +from dataclasses import dataclass + + +@dataclass +class Seq2SeqV2DialOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None + diff --git a/models/backbones/encoder_decoder/xbart.py b/models/backbones/encoder_decoder/xbart.py new file mode 100644 index 0000000..0183a4a --- /dev/null +++ b/models/backbones/encoder_decoder/xbart.py @@ -0,0 +1,2044 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BART model.""" +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.bart.configuration_bart import BartConfig +from .outputs import Seq2SeqV2DialOutput + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-base" +_CONFIG_FOR_DOC = "BartConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 768] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" +_SEQ_CLASS_EXPECTED_LOSS = 0.0 +_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" + +# QuestionAsnwering docstring +_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" +_QA_EXPECTED_LOSS = 0.59 +_QA_EXPECTED_OUTPUT = "' nice puppet'" + + +BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/bart-large", + # see all BART models at https://huggingface.co/models?filter=bart +] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class BartEncoderLayer(nn.Module): + def __init__(self, config: BartConfig, has_cross_att=False): + super().__init__() + self.embed_dim = config.d_model + self.has_cross_att = has_cross_att + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + if self.has_cross_att: + self.cross_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + key_value_states: Optional[torch.FloatTensor] = None, + cross_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + key_value_states=key_value_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + ######################################## + if self.has_cross_att: + assert cross_hidden_states is not None + assert cross_attention_mask is not None + residual = hidden_states + hidden_states, attn_weights, _ = self.cross_attn( + hidden_states=hidden_states, + key_value_states=cross_hidden_states, + attention_mask=cross_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.cross_attn_layer_norm(hidden_states) + ######################################## + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BartDecoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BartPreTrainedModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BartDecoder, BartEncoder)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class PretrainedBartModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartPretrainedModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +BART_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BartConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BART_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + + >>> TXT = "My friends are but they eat too many carbs." + >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['not', 'good', 'healthy', 'great', 'very'] + ``` +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you + can choose to directly pass an embedded representation. This is useful if you want more control over how to + convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class BartEncoder(BartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BartEncoderLayer`]. + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None, expert_type: Optional[str] = None, modality: Optional[str] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.expert_type = expert_type + self.modality = modality + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + + + # self.modality_expert_layers = None + # if hasattr(config, 'modality_expert_layers'): + # encoder_layers = config.modality_expert_layers + + if self.expert_type is None: + encoder_layers = config.encoder_layers + self.cross_att_every = encoder_layers + 1 # No cross attention + + elif self.expert_type == 'modality': + encoder_layers = config.modality_expert_layers + if self.modality in ['spatial', 'temporal']: + self.cross_att_every = 2 # Cross attention every two layers + else: + self.cross_att_every = encoder_layers + 1 # No cross attention + + + elif self.expert_type == 'grounding': + encoder_layers = config.grounding_expert_layers + self.cross_att_every = 1 # Cross attention at every layer + + layers = [] + for i in range(encoder_layers): + has_cross_att = i % self.cross_att_every == 0 + if self.cross_att_every > encoder_layers: + has_cross_att = False + layers.append(BartEncoderLayer(config, has_cross_att=has_cross_att)) + + self.layers = nn.ModuleList(layers) + # self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + # self.grounding_expert_layers = range(config.modality_expert_layers, config.encoder_layers) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def remove_first_k_layers(self, first_k): + assert first_k < len(self.layers) and first_k > 0 + self.layers = self.layers[first_k:] + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cross_embeds: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + # expert_type: Optional[str] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + if cross_attention_mask is not None: + cross_attention_mask = _expand_mask(cross_attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.size(1)) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + cross_hidden_states=cross_embeds, + cross_attention_mask=cross_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_hidden_states=cross_embeds, + cross_attention_mask=cross_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BartDecoder(BartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare BART Model outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class BartModel(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig): + super().__init__(config, embed=None) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqV2DialOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_outputs=encoder_outputs + ) + + +@add_start_docstrings( + "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING +) +class BartForConditionalGeneration(BartPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(BART_GENERATION_EXAMPLE) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqV2DialOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + encoder_outputs=outputs.encoder_outputs + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + BART_START_DOCSTRING, +) +class BartForSequenceClassification(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BartModel(config) + self.classification_head = BartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BART_START_DOCSTRING, +) +class BartForQuestionAnswering(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BartModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_QA, + output_type=Seq2SeqQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=_QA_EXPECTED_LOSS, + expected_output=_QA_EXPECTED_OUTPUT, + ) + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class BartDecoderWrapper(BartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@add_start_docstrings( + """ + BART decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings). + """, + BART_START_DOCSTRING, +) +class BartForCausalLM(BartPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = BartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if input_ids is None and inputs_embeds is None: + input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/models/backbones/encoder_decoder/xbart_original.py b/models/backbones/encoder_decoder/xbart_original.py new file mode 100644 index 0000000..b6de4ff --- /dev/null +++ b/models/backbones/encoder_decoder/xbart_original.py @@ -0,0 +1,1954 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BART model.""" +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.bart.configuration_bart import BartConfig +from .outputs import Seq2SeqV2DialOutput + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-base" +_CONFIG_FOR_DOC = "BartConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 768] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" +_SEQ_CLASS_EXPECTED_LOSS = 0.0 +_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" + +# QuestionAsnwering docstring +_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" +_QA_EXPECTED_LOSS = 0.59 +_QA_EXPECTED_OUTPUT = "' nice puppet'" + + +BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/bart-large", + # see all BART models at https://huggingface.co/models?filter=bart +] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class BartEncoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BartDecoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BartPreTrainedModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BartDecoder, BartEncoder)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class PretrainedBartModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartPretrainedModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +BART_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BartConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BART_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + + >>> TXT = "My friends are but they eat too many carbs." + >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['not', 'good', 'healthy', 'great', 'very'] + ``` +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you + can choose to directly pass an embedded representation. This is useful if you want more control over how to + convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class BartEncoder(BartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BartEncoderLayer`]. + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BartDecoder(BartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare BART Model outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class BartModel(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqV2DialOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_outputs=encoder_outputs + ) + + +@add_start_docstrings( + "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING +) +class BartForConditionalGeneration(BartPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(BART_GENERATION_EXAMPLE) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + # loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqV2DialOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + encoder_outputs=outputs.encoder_outputs + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + BART_START_DOCSTRING, +) +class BartForSequenceClassification(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BartModel(config) + self.classification_head = BartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BART_START_DOCSTRING, +) +class BartForQuestionAnswering(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BartModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_QA, + output_type=Seq2SeqQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=_QA_EXPECTED_LOSS, + expected_output=_QA_EXPECTED_OUTPUT, + ) + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class BartDecoderWrapper(BartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@add_start_docstrings( + """ + BART decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings). + """, + BART_START_DOCSTRING, +) +class BartForCausalLM(BartPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = BartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past \ No newline at end of file diff --git a/models/backbones/encoder_decoder/xflan_t5.py b/models/backbones/encoder_decoder/xflan_t5.py new file mode 100644 index 0000000..db3ab99 --- /dev/null +++ b/models/backbones/encoder_decoder/xflan_t5.py @@ -0,0 +1,2075 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config +from .outputs import Seq2SeqV2DialOutput + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" + ) +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = hidden_states + self.dropout(attention_output[0]) + + if torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + + if torch.isinf(layer_output).any(): + clamp_value = torch.finfo(layer_output.dtype).max - 1000 + layer_output = torch.clamp(layer_output, min=-clamp_value, max=clamp_value) + + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if ( + hidden_states.dtype == torch.float16 + and torch.isinf(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["T5Block"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + assert ( + self.is_decoder + ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING +) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + if reduction == "none": + loss = loss.view(lm_logits.size(0), -1).sum(1) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqV2DialOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_outputs=encoder_outputs + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder.embed_tokens.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5EncoderModel + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5EncoderModel.from_pretrained("t5-small") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/models/backbones/eva_vit.py b/models/backbones/eva_vit.py new file mode 100755 index 0000000..84e3fc3 --- /dev/null +++ b/models/backbones/eva_vit.py @@ -0,0 +1,455 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from models.common.dist_utils import download_cached_file + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + def get_num_layer(self, var_name=""): + if var_name in ("cls_token", "mask_token", "pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("rel_pos_bias"): + return len(self.blocks) - 1 + elif var_name.startswith("blocks"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + else: + return len(self.blocks) + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + # depth = 37, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) +# print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/models/backbones/mini_gpt4_llama_v2.py b/models/backbones/mini_gpt4_llama_v2.py new file mode 100755 index 0000000..175fbeb --- /dev/null +++ b/models/backbones/mini_gpt4_llama_v2.py @@ -0,0 +1,895 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model +# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub + +from transformers import LlamaTokenizer +from transformers import BitsAndBytesConfig + +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + prepare_model_for_int8_training, + set_peft_model_state_dict, +) +import time +import numpy as np + +from minigpt4.models import policies + + +@registry.register_model("mini_gpt4_llama_v2") +class MiniGPT4_llama_v2(Blip2Base): + """ + BLIP2 GPT-LLAMA model. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/minigpt4.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + low_resource=False, # use 8 bit and put vit in cpu + end_sym='\n', + lora_r = 8, + lora_target_modules = ["q_proj","v_proj"], + lora_alpha=16, + # lora_r = 16, + # lora_target_modules = ["q_proj","v_proj","v_proj"], + lora_dropout= 0.05, + ckpt_path = "", + system_prompt= False, + chat_template=False, + token_pooling=True, + use_grad_checkpoint_llm=False, + max_context_len=3800, + remove_template = False, + + ): + super().__init__() + if "Mistral" in llama_model: + from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model + print("Mistral model") + self.model_type = "Mistral" + else: + from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model + print("Llama model") + self.model_type = "Llama" + self.tokenizer = self.init_tokenizer() + self.low_resource = low_resource + self.token_pooling = token_pooling + self.remove_template = remove_template + + print("token pooling", self.token_pooling) + + + self.use_grad_checkpoint_llm = use_grad_checkpoint_llm + self.max_context_len = max_context_len + self.chat_template = chat_template + + # print('Loading VIT') + # self.visual_encoder, self.ln_vision = self.init_vision_encoder( + # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + # ) + + if freeze_vit: + # vit_precision="fp32" + print("vit precision", vit_precision) + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print("freeze the vision encoder") + + else: + vit_precision="fp32" + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + + print("unfreeze the vision encoder") + + print('Loading VIT Done') + + # print("visual encoder shape", self.visual_encoder.pos_embed.shape) + # assert False + + print('Loading LLAMA') + + + self.B_SYS, self.E_SYS = "<>\n", "\n<>\n\n" + + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model,use_fast=False) # + self.llama_tokenizer.pad_token = "$$" + + self.system_prompt = system_prompt + + + + print("self.low_resource",self.low_resource) + if self.low_resource: + self.llama_model = llm_model.from_pretrained( + llama_model, + torch_dtype=torch.float16, + # torch_dtype = torch.bfloat16, + load_in_8bit=True, + # device_map = "balanced" + # device_map="auto", + device_map={'':torch.cuda.current_device()}, + # device_map={'':0} + + ) + # bnb_config = BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_use_double_quant=True, + # bnb_4bit_quant_type="nf4", + # bnb_4bit_compute_dtype=torch.bfloat16, + # ) + # self.llama_model = llm_model.from_pretrained( + # llama_model, + # torch_dtype=torch.bfloat16, + # device_map={'':torch.cuda.current_device()}, + # quantization_config=bnb_config, + # ) + else: + self.llama_model = llm_model.from_pretrained( + llama_model, + torch_dtype=torch.float16, + ) + + + + # self.llama_model.resize_token_embeddings(len(self.llama_tokenizer)) + self.llama_model = prepare_model_for_int8_training(self.llama_model) + + + + loraconfig = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM" + ) + self.llama_model = get_peft_model(self.llama_model, loraconfig) + + # if ckpt_path: + # print('load the llm under lora') + # ckpt = torch.load(ckpt_path) + # set_peft_model_state_dict(self.llama_model,ckpt) + + + + self.llama_model.print_trainable_parameters() + + if self.use_grad_checkpoint_llm: + self.llama_model.gradient_checkpointing_enable() + + # if not self.low_resource: + # for name, param in self.llama_model.named_parameters(): + # if "embed_token" in name: + # param.data = param.data.float() + # param.requires_grad = True + + + print('Loading LLAMA Done') + + + if self.token_pooling: + self.llama_proj = nn.Linear( + 1408*4, self.llama_model.config.hidden_size + ) + else: + self.llama_proj = nn.Linear( + 1408, self.llama_model.config.hidden_size + ) + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + def encode_img(self, image): + device = image.device + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + bs, pn, hs = image_embeds.shape + if self.token_pooling: # concat the each 4 tokens into one token (200,64,5632) + image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + def get_context_emb(self, prompt, img_list): + img_device = img_list[0].device + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + + mixed_embs = torch.cat(mixed_embs, dim=1) + # # truncate the length of tokens to the max context window + # mixed_embs_without_instruction = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + # mixed_embs_without_instruction=torch.cat(mixed_embs_without_instruction, dim=1) + # # check if the number of token in the second dimention is more than the max context window then truncate it + # context_window=self.max_context_len-seg_embs[-1].shape[1] + # if mixed_embs_without_instruction.shape[1] > context_window : + # mixed_embs_without_instruction = mixed_embs_without_instruction[:, 0:context_window] + # mixed_embs=torch.cat([mixed_embs_without_instruction,seg_embs[-1]], dim=1) + # print("mixed_embs",mixed_embs.shape) + + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): + if prompts is None or len(prompts) == 0: + # prompts is not provided, just return the original image embedding + return img_embeds, atts_img + elif img_embeds is None: + # prompt is provided but there is no image embedding. return the prompt embedding in right padding + self.llama_tokenizer.padding_side = "right" + prompt_tokens = self.llama_tokenizer( + prompts, + return_tensors="pt", + padding="longest", + add_special_tokens=False + ).to(self.device) + prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) + atts_prompt = prompt_tokens.attention_mask + return prompt_embeds, atts_prompt + + else: + # return the multi-modal embedding in right padding + emb_lists = [] + + for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): + pn = each_img_embed.shape[-2] + if lengths is not None: + each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) + each_img_embed = each_img_embed[:lengths[idx] * pn] + + p_segs = each_prompt.split('') + interleave_emb = [] + for idx, seg in enumerate(p_segs[:-1]): + p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + # print("p_embed device",p_tokens.input_ids.device) + # print("p_tokens",img_embeds.device) + # print("emb layer", list(self.llama_model.base_model.model.model.embed_tokens.parameters())[0].device) + p_embed = self.embed_tokens(p_tokens.input_ids) + + # print("model device",self.llama_model.get_device()) + interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1)) + + wrapped_emb = torch.cat(interleave_emb, dim=1) + p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1) + emb_lists.append(wrapped_emb) + + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + + # max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len + max_length = self.max_context_len + wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) + + for i, emb in enumerate(emb_lists): + length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len + wrapped_embs[i, :length] = emb[:, :length] + wrapped_atts[i, :length] = 1 + + return wrapped_embs, wrapped_atts + + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + """ + Concatenate the batched input embedding and batched output embedding together. + Both the input and the output embedding should be right padded. + """ + + input_lens = [] + cat_embs = [] + cat_atts = [] + + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + # print('===================================') + # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones]) + # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2]) + # print('check out emb: ', output_embs[i][:2]) + # print('check out pad emb: ', output_embs[i][-2:]) + # print('+++++++++++++++++++++++++++++++++++') + # + # print('check attn before: ', input_atts[i][:this_input_ones]) + # print('check attn after: ', input_atts[i][this_input_ones:]) + # print('check attn gt before: ', output_atts[i][:3]) + # print('check attn gt after: ', output_atts[i][-3:]) + + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def get_conv_emb(self, conv_q, conv_a, conv_img): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + regress_embs_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + assigned_imgs = conv_img[batch_idx] + questions = [self.prompt_wrap( + img_embeds=img, + atts_img=None, + prompts=[q], + lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)] + q_embs = [emb for emb, _ in questions] + + answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers] + cur_emb = [] + cur_target = [] + for i in range(len(questions)): + cur_emb.append(q_embs[i]) + cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100) + + cur_emb.append(self.embed_tokens(answers[i].input_ids)) + cur_target.append(answers[i].input_ids) + + cur_emb = torch.cat(cur_emb, dim=1) + cur_target = torch.cat(cur_target, dim=1) + + regress_embs_list.append(cur_emb) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + + regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device) + regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device) + targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100 + + for batch_idx in range(batch_size): + cur_len = regress_embs_list[batch_idx].shape[1] + regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len] + regress_attn[batch_idx, :cur_len] = 1 + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + return regress_embeds, regress_attn, targets + + def preparing_embedding(self, samples): + def remove_special_tokens(data): + + # if "instruction_input" in data: + data = [instruct.replace(" [caption]","") for instruct in data] + data = [instruct.replace(" [vqa]","") for instruct in data] + data = [instruct.replace(" [grounding]","") for instruct in data] + data = [instruct.replace(" [identify]","") for instruct in data] + data = [instruct.replace(" [refer]","") for instruct in data] + return data + + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + # print("img_embeds shape",img_embeds.shape) + else: + img_embeds = img_atts = None + + if 'conv_q' in samples: + # handeling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym)for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + conv_img = assign_imgs(conv_q, img_embeds) + + if self.chat_template: + conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q] + + regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img) + cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0] + + else: + instruction = samples["instruction_input"] if "instruction_input" in samples else None + + # print("instruction before", instruction) + if self.remove_template: + instruction = remove_special_tokens(instruction) + # print("instruction after", instruction) + + if self.chat_template: + instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction] + + if 'length' in samples: + # the input is a image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096) + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) + else: + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction) + + ### prepare target tokens + self.llama_tokenizer.padding_side = "right" + text = [t + self.end_sym for t in samples["answer"]] + + regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(self.device) + + regress_token_ids = regress_tokens.input_ids + regress_atts = regress_tokens.attention_mask + part_targets = regress_token_ids.masked_fill( + regress_token_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + regress_embeds = self.embed_tokens(regress_token_ids) + + return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets + + def forward(self, samples, reduction="mean"): + # prepare the embedding to condition and the embedding to regress + cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \ + self.preparing_embedding(samples) + + # concat the embedding to condition and the embedding to regress + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) + print("inputs_embeds shape",inputs_embeds.shape) + print("cond_embeds shape",cond_embeds.shape) + print("regress_embeds shape",regress_embeds.shape) + # get bos token embedding + bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id + bos_embeds = self.embed_tokens(bos) + bos_atts = attention_mask[:, :1] + + # add bos token at the begining + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_atts, attention_mask], dim=1) + # print length of instruction_input and answer words + # for i in range (len(samples["instruction_input"])): + # print("instruction_input length",len(samples["instruction_input"][i].split(" "))) + # print("answer length",len(samples["answer"][i].split(" "))) + # ensemble the final targets + targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(self.device).fill_(-100) + for i, target in enumerate(part_targets): + targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos + print("targets shape",targets.shape) + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + reduction=reduction + ) + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + images, + texts, + use_nucleus_sampling=False, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + lengths=None, + return_video_temporal_features=False, + img_embeds=None, + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + if img_embeds is None: + img_embeds, atts_img = self.encode_img(images.to(self.device)) + else: + # Use images features from the input(4,45,64,5632) + img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:]) + img_embeds= img_embeds.to(self.device) + img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) + atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device) + + print("img_embeds shape",img_embeds.shape) + if lengths is not None: + image_lists = [] + img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1]) + for idx, img_embed in enumerate(img_embeds): + image_lists.append([img_embed[i][None] for i in range(lengths[idx])]) + else: + image_lists = [[image_emb[None]] for image_emb in img_embeds] + assert len(texts) == len(image_lists) + batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] + + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + # print("inputs_embeds shape",embs.shape) + # print("attention_mask shape",attn_mask.shape) + # check if the input embedding tokens are in the range of the model cotext window (4096) and if it is not, then truncate it to the max context window + if self.model_type == "Llama": + context_window = 3700 + else: + context_window = 7500 + if embs.shape[1] > context_window: + embs = embs[:, -context_window:] + attn_mask = attn_mask[:, -context_window:] + print("inputs_embeds shape",embs.shape) + print("attention_mask shape",attn_mask.shape) + with self.maybe_autocast(): + if return_video_temporal_features: + last_hidden_state = self.llama_model( + inputs_embeds=embs, + attention_mask=attn_mask, + output_hidden_states=True, + ).hidden_states[-1] + video_temporal_features = last_hidden_state.mean(dim=1) + # normalize the temporal features using L2 norm + # video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True) + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + temperature=temperature, + repetition_penalty=repetition_penalty, + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign + output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + if return_video_temporal_features: + return answers, video_temporal_features + else: + return answers + + @torch.no_grad() + def generate_text_only( + self, + images, + seg_tokens, + use_nucleus_sampling=False, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + lengths=None, + return_video_temporal_features=False, + img_embeds=None, + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + # seg_tokens=[] + # for i, text in enumerate(texts): + # seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids) + + batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens] + + # seg_embs = torch.cat(seg_embs, dim=1) + # print("seg_embs shape",seg_embs.shape) + # batch_embs=[seg_embs] + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + + print("inputs_embeds shape",embs.shape) + print("attention_mask shape",attn_mask.shape) + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + temperature=temperature, + repetition_penalty=repetition_penalty, + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign + output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + return answers + + + + @torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for open-ended VQA + ''' + images = samples["image"].cuda() + texts = samples["instruction_input"] + + output_text = self.generate( + images=images, + texts=texts, + num_beams=num_beams, + max_new_tokens=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: + output_text = self._lemmatize(output_text) + + return output_text + + def predict_class( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=5, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for multi-choice VQA + ''' + + image = samples["image"].cuda() + instruction = samples['instruction_input'] + answers = samples["choices"] + num_cand = samples["num_choices"] + + ranks = self.multi_select(image, instruction, answers, num_cand) + + pred_ans = [] + for i, rank in enumerate(ranks): + pred = answers[rank[0]][i] + pred_ans.append(pred) + return pred_ans + + def embed_tokens(self, token_ids): + try: + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + except AttributeError: + embeds = self.llama_model.model.embed_tokens(token_ids) + + return embeds + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r",64) + lora_alpha = cfg.get("lora_alpha",16) + chat_template = cfg.get("chat_template",False) + system_prompt = cfg.get("system_prompt", False) + token_pooling = cfg.get("token_pooling",True) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + remove_template = cfg.get("remove_template", False) + + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r = lora_r, + lora_alpha = lora_alpha, + chat_template = chat_template, + system_prompt = system_prompt, + token_pooling = token_pooling, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + remove_template = remove_template + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model + + +def assign_imgs(batched_instruct_list, batched_img_embeds): + '''this function is used when the data is interleaved. + the interlevaed data is separated, and this function assign + corresponding image embeddings to each segment''' + if len(batched_img_embeds.shape) == 3: + batched_img_embeds = batched_img_embeds[:, None] + + batched_assigned = [] + + for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds): + img_idx = 0 + assigned_img = [] + n_assigned = [] + for instruct in instruct_list: + n_img = instruct.count('') + if n_img > 0: # this instruction include images. + assigned_img.append(img_embeds[None, img_idx:img_idx+n_img]) + img_idx += n_img + n_assigned.append(n_img) + else: # this instruction doesn't include images + assigned_img.append(None) + n_assigned.append(None) + batched_assigned.append(assigned_img) + + return batched_assigned \ No newline at end of file diff --git a/models/backbones/mini_gpt4v.py b/models/backbones/mini_gpt4v.py new file mode 100755 index 0000000..c3d5e4d --- /dev/null +++ b/models/backbones/mini_gpt4v.py @@ -0,0 +1,709 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub + +from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig + +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training +) +import time +import numpy as np + +from minigpt4.models import policies + + +@registry.register_model("mini_gpt4v") +class MiniGPT4v(Blip2Base): + """ + BLIP2 GPT-LLAMA model. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/minigpt4.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + low_resource=False, # use 8 bit and put vit in cpu + end_sym='\n', + lora_r = 8, + lora_target_modules = ["q_proj","v_proj"], + lora_alpha=16, + # lora_r = 16, + # lora_target_modules = ["q_proj","v_proj","v_proj"], + lora_dropout= 0.05, + ckpt_path = "", + system_prompt= False, + chat_template=False, + token_pooling=True, + use_grad_checkpoint_llm=False, + max_context_len=3800, + remove_template = False, + + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.low_resource = low_resource + self.token_pooling = token_pooling + self.remove_template = remove_template + + print("token pooling", self.token_pooling) + + + self.use_grad_checkpoint_llm = use_grad_checkpoint_llm + self.max_context_len = max_context_len + self.chat_template = chat_template + + # print('Loading VIT') + # self.visual_encoder, self.ln_vision = self.init_vision_encoder( + # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + # ) + + + print("vit precision", vit_precision) + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision + ) + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print("freeze the vision encoder") + + + print('Loading VIT Done') + + # print("visual encoder shape", self.visual_encoder.pos_embed.shape) + # assert False + + print('Loading LLAMA') + + + self.B_SYS, self.E_SYS = "<>\n", "\n<>\n\n" + + if 'CodeLlama' in llama_model: + self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) # + self.llama_tokenizer.pad_token = "$$" + else: + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) # + self.llama_tokenizer.pad_token = "$$" + + self.system_prompt = system_prompt + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16 + ) + + + + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + quantization_config=bnb_config, + device_map={"": 0} + ) + + # self.llama_model.gradient_checkpointing_enable() + self.llama_model = prepare_model_for_kbit_training(self.llama_model) + + # self.llama_model.print_trainable_parameters() + + + print('Loading LLAMA Done') + + self.merge_n = 3 + + self.llama_proj = nn.Linear( + 1408 * self.merge_n**2, self.llama_model.config.hidden_size + ) + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + def encode_img(self, image): + device = image.device + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + bs, ch, w, h = image.shape + assert w % 224 == 0 + bw = w // 224 + assert h % 224 == 0 + bh = h // 224 + image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224 + image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224) + + with self.maybe_autocast(): + image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device) + + image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1]) + image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs + image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1]) + + bs, pn, hs = image_embeds.shape + + image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + def get_context_emb(self, prompt, img_list): + img_device = img_list[0].device + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): + if prompts is None or len(prompts) == 0: + # prompts is not provided, just return the original image embedding + return img_embeds, atts_img + elif img_embeds is None: + # prompt is provided but there is no image embedding. return the prompt embedding in right padding + self.llama_tokenizer.padding_side = "right" + prompt_tokens = self.llama_tokenizer( + prompts, + return_tensors="pt", + padding="longest", + add_special_tokens=False + ).to(self.device) + prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) + atts_prompt = prompt_tokens.attention_mask + return prompt_embeds, atts_prompt + + else: + # return the multi-modal embedding in right padding + emb_lists = [] + + for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): + pn = each_img_embed.shape[-2] + if lengths is not None: + each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) + each_img_embed = each_img_embed[:lengths[idx] * pn] + + p_segs = each_prompt.split('') + interleave_emb = [] + for idx, seg in enumerate(p_segs[:-1]): + p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1)) + + wrapped_emb = torch.cat(interleave_emb, dim=1) + p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1) + emb_lists.append(wrapped_emb) + + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + + max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len + wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) + + for i, emb in enumerate(emb_lists): + length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len + wrapped_embs[i, :length] = emb[:, :length] + wrapped_atts[i, :length] = 1 + + return wrapped_embs, wrapped_atts + + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + """ + Concatenate the batched input embedding and batched output embedding together. + Both the input and the output embedding should be right padded. + """ + + input_lens = [] + cat_embs = [] + cat_atts = [] + + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + # print('===================================') + # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones]) + # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2]) + # print('check out emb: ', output_embs[i][:2]) + # print('check out pad emb: ', output_embs[i][-2:]) + # print('+++++++++++++++++++++++++++++++++++') + # + # print('check attn before: ', input_atts[i][:this_input_ones]) + # print('check attn after: ', input_atts[i][this_input_ones:]) + # print('check attn gt before: ', output_atts[i][:3]) + # print('check attn gt after: ', output_atts[i][-3:]) + + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def get_conv_emb(self, conv_q, conv_a, conv_img): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + regress_embs_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + assigned_imgs = conv_img[batch_idx] + questions = [self.prompt_wrap( + img_embeds=img, + atts_img=None, + prompts=[q], + lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)] + q_embs = [emb for emb, _ in questions] + + answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers] + cur_emb = [] + cur_target = [] + for i in range(len(questions)): + cur_emb.append(q_embs[i]) + cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100) + + cur_emb.append(self.embed_tokens(answers[i].input_ids)) + cur_target.append(answers[i].input_ids) + + cur_emb = torch.cat(cur_emb, dim=1) + cur_target = torch.cat(cur_target, dim=1) + + regress_embs_list.append(cur_emb) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + + regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device) + regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device) + targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100 + + for batch_idx in range(batch_size): + cur_len = regress_embs_list[batch_idx].shape[1] + regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len] + regress_attn[batch_idx, :cur_len] = 1 + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + return regress_embeds, regress_attn, targets + + def preparing_embedding(self, samples): + def remove_special_tokens(data): + + # if "instruction_input" in data: + data = [instruct.replace(" [caption]","") for instruct in data] + data = [instruct.replace(" [vqa]","") for instruct in data] + data = [instruct.replace(" [grounding]","") for instruct in data] + data = [instruct.replace(" [identify]","") for instruct in data] + data = [instruct.replace(" [refer]","") for instruct in data] + return data + + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + else: + img_embeds = img_atts = None + + if 'conv_q' in samples: + # handeling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym)for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + conv_img = assign_imgs(conv_q, img_embeds) + + if self.chat_template: + conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q] + + regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img) + cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0] + + else: + instruction = samples["instruction_input"] if "instruction_input" in samples else None + + # print("instruction before", instruction) + if self.remove_template: + instruction = remove_special_tokens(instruction) + # print("instruction after", instruction) + + if self.chat_template: + instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction] + + if 'length' in samples: + # the input is a image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) + else: + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction) + + ### prepare target tokens + self.llama_tokenizer.padding_side = "right" + text = [t + self.end_sym for t in samples["answer"]] + + regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(self.device) + + regress_token_ids = regress_tokens.input_ids + regress_atts = regress_tokens.attention_mask + part_targets = regress_token_ids.masked_fill( + regress_token_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + regress_embeds = self.embed_tokens(regress_token_ids) + + return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets + + def forward(self, samples, reduction="mean"): + # prepare the embedding to condition and the embedding to regress + cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \ + self.preparing_embedding(samples) + + # concat the embedding to condition and the embedding to regress + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) + + # get bos token embedding + bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id + bos_embeds = self.embed_tokens(bos) + bos_atts = attention_mask[:, :1] + + # add bos token at the begining + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_atts, attention_mask], dim=1) + + # ensemble the final targets + targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(self.device).fill_(-100) + for i, target in enumerate(part_targets): + targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + reduction=reduction + ) + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + images, + texts, + use_nucleus_sampling=False, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + lengths=None, + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + img_embeds, atts_img = self.encode_img(images.to(self.device)) + if lengths is not None: + image_lists = [] + img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1]) + for idx, img_embed in enumerate(img_embeds): + image_lists.append([img_embed[i][None] for i in range(lengths[idx])]) + else: + image_lists = [[image_emb[None]] for image_emb in img_embeds] + assert len(texts) == len(image_lists) + batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] + + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign + output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + + return answers + + @torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for open-ended VQA + ''' + images = samples["image"].cuda() + texts = samples["instruction_input"] + + output_text = self.generate( + images=images, + texts=texts, + num_beams=num_beams, + max_new_tokens=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: + output_text = self._lemmatize(output_text) + + return output_text + + def predict_class( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=5, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for multi-choice VQA + ''' + + image = samples["image"].cuda() + instruction = samples['instruction_input'] + answers = samples["choices"] + num_cand = samples["num_choices"] + + ranks = self.multi_select(image, instruction, answers, num_cand) + + pred_ans = [] + for i, rank in enumerate(ranks): + pred = answers[rank[0]][i] + pred_ans.append(pred) + return pred_ans + + def embed_tokens(self, token_ids): + try: + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + except AttributeError: + embeds = self.llama_model.model.embed_tokens(token_ids) + + return embeds + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r",64) + lora_alpha = cfg.get("lora_alpha",16) + chat_template = cfg.get("chat_template",False) + system_prompt = cfg.get("system_prompt", False) + token_pooling = cfg.get("token_pooling",True) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + remove_template = cfg.get("remove_template", False) + + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r = lora_r, + lora_alpha = lora_alpha, + chat_template = chat_template, + system_prompt = system_prompt, + token_pooling = token_pooling, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + remove_template = remove_template + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model + + +def assign_imgs(batched_instruct_list, batched_img_embeds): + '''this function is used when the data is interleaved. + the interlevaed data is separated, and this function assign + corresponding image embeddings to each segment''' + if len(batched_img_embeds.shape) == 3: + batched_img_embeds = batched_img_embeds[:, None] + + batched_assigned = [] + + for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds): + img_idx = 0 + assigned_img = [] + n_assigned = [] + for instruct in instruct_list: + n_img = instruct.count('') + if n_img > 0: # this instruction include images. + assigned_img.append(img_embeds[None, img_idx:img_idx+n_img]) + img_idx += n_img + n_assigned.append(n_img) + else: # this instruction doesn't include images + assigned_img.append(None) + n_assigned.append(None) + batched_assigned.append(assigned_img) + + return batched_assigned \ No newline at end of file diff --git a/models/backbones/mistral.py b/models/backbones/mistral.py new file mode 100644 index 0000000..43095ff --- /dev/null +++ b/models/backbones/mistral.py @@ -0,0 +1,25 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +device = "cuda" # the device to load the model onto + +model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") + +messages = [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} +] +p="Well, I'm quite partial to a good squeeze of fresh lemon juice." +encoded_input = tokenizer(p, return_tensors='pt') +embeds = model.model.embed_tokens(encoded_input.input_ids) +print(embeds.shape) + + +encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt") +model_inputs = encodeds.to(device) +model.to(device) + +generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True) +decoded = tokenizer.batch_decode(generated_ids) +print(decoded[0]) diff --git a/models/backbones/modeling_llama_v2.py b/models/backbones/modeling_llama_v2.py new file mode 100644 index 0000000..3043af0 --- /dev/null +++ b/models/backbones/modeling_llama_v2.py @@ -0,0 +1,112 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig + + +class LlamaForCausalLM(LlamaForCausalLMOrig): + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/models/backbones/modeling_llama_v3.py b/models/backbones/modeling_llama_v3.py new file mode 100644 index 0000000..3043af0 --- /dev/null +++ b/models/backbones/modeling_llama_v3.py @@ -0,0 +1,112 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig + + +class LlamaForCausalLM(LlamaForCausalLMOrig): + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/models/backbones/modeling_mistral.py b/models/backbones/modeling_mistral.py new file mode 100644 index 0000000..3a98c7d --- /dev/null +++ b/models/backbones/modeling_mistral.py @@ -0,0 +1,1388 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mistral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.models.mistral.configuration_mistral import MistralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# TODO @Arthur no longer copied from LLama after static cache +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, +} + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/models/backbones/moes.py b/models/backbones/moes.py new file mode 100644 index 0000000..3f0914b --- /dev/null +++ b/models/backbones/moes.py @@ -0,0 +1,287 @@ +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from timm.models.layers import DropPath + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if mask is not None: + # if mask.dim() != x.dim(): + # expanded_mask = mask[:, None, None, :].expand(B, 1, N, N) + # else: + # expanded_mask = mask + mask = mask.bool() + attn = attn.masked_fill(~mask, float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class MoELayer(nn.Module): + def __init__( + self, + dim, + num_heads, + expert_type, + use_sep_spatial_temp_experts=True, + has_hist=False, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=LlamaRMSNorm, + ): + super().__init__() + self.has_hist = has_hist + self.use_sep_spatial_temp_experts = use_sep_spatial_temp_experts + self.norm_att = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + mlp_hidden_dim = int(dim * mlp_ratio) + + if expert_type == 'modalities': + # EXPERT CONSTRUCTION + if use_sep_spatial_temp_experts: + # Spatial expert + self.norm_spatial = norm_layer(dim) + self.mlp_spatial = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Temporal expert + self.norm_temp = norm_layer(dim) + self.mlp_temp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Vis expert + self.norm_vis = norm_layer(dim) + self.mlp_vis = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # caption expert + self.norm_cap = norm_layer(dim) + self.mlp_cap = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # history expert + if has_hist: + self.norm_hist = norm_layer(dim) + self.mlp_hist = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + elif expert_type == 'fusion': + # Fusion expert + self.norm_fusion = norm_layer(dim) + self.mlp_fusion = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + else: + raise ValueError + + + def forward(self, x, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len=None, is_vid=False, mask=None, only_text=False, expert_permutation=None): + + if self.has_hist: + assert hist_feat_len is not None + + x_shortcut, attn = self.attn(self.norm_att(x), mask=mask) + x = x + self.drop_path(x_shortcut) + len_init = x.size(1) + # bs, h_dim = x.size(0), x.size(-1) + # device = x.device + # if only_text: + # # end_idx_caption = special_toks_indices.get('', special_toks_indices[''] + 1) + # # x = x[:, special_toks_indices['']: end_idx_caption, :] + # x = x + self.drop_path(self.mlp_cap(self.norm_cap(x))) + + if expert_flag == 'modalities': + if self.use_sep_spatial_temp_experts: + x_spatial = x[:, :vis_feat_len] + if expert_permutation is not None: + if expert_permutation['spatial'] == 'temporal': + x_spatial = x_spatial + self.drop_path(self.mlp_temp(self.norm_temp(x_spatial))) + elif expert_permutation['spatial'] == 'caption': + x_spatial = x_spatial + self.drop_path(self.mlp_cap(self.norm_cap(x_spatial))) + elif expert_permutation['spatial'] == 'history': + x_spatial = x_spatial + self.drop_path(self.mlp_hist(self.norm_hist(x_spatial))) + elif expert_permutation['spatial'] == 'spatial': + x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial))) + x_vis = x_spatial + + else: + x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial))) + x_vis = x_spatial + + if is_vid: + x_temporal = x[:, vis_feat_len:2*vis_feat_len] + if expert_permutation is not None: + if expert_permutation['temporal'] == 'spatial': + x_temporal = x_temporal + self.drop_path(self.mlp_spatial(self.norm_spatial(x_temporal))) + elif expert_permutation['temporal'] == 'caption': + x_temporal = x_temporal + self.drop_path(self.mlp_cap(self.norm_cap(x_temporal))) + elif expert_permutation['temporal'] == 'history': + x_temporal = x_temporal + self.drop_path(self.mlp_hist(self.norm_hist(x_temporal))) + elif expert_permutation['temporal'] == 'temporal': + x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal))) + else: + x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal))) + x_vis = torch.concat([x_spatial, x_temporal], dim=1) + x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis))) + else: + x_vis = x[:, :vis_feat_len] + x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis))) + + if self.has_hist: + x_caption = x[:, -(cap_feat_len + hist_feat_len): -hist_feat_len] + if expert_permutation is not None: + if expert_permutation['caption'] == 'spatial': + x_caption = x_caption + self.drop_path(self.mlp_spatial(self.norm_spatial(x_caption))) + elif expert_permutation['caption'] == 'temporal': + x_caption = x_caption + self.drop_path(self.mlp_temp(self.norm_temp(x_caption))) + elif expert_permutation['caption'] == 'history': + x_caption = x_caption + self.drop_path(self.mlp_hist(self.norm_hist(x_caption))) + elif expert_permutation['caption'] == 'caption': + x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption))) + else: + x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption))) + + + x_history = x[:, -hist_feat_len:] + if expert_permutation is not None: + if expert_permutation['history'] == 'spatial': + x_history = x_history + self.drop_path(self.mlp_spatial(self.norm_spatial(x_history))) + elif expert_permutation['history'] == 'temporal': + x_history = x_history + self.drop_path(self.mlp_temp(self.norm_temp(x_history))) + elif expert_permutation['history'] == 'caption': + x_history = x_history + self.drop_path(self.mlp_cap(self.norm_cap(x_history))) + elif expert_permutation['history'] == 'history': + x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history))) + else: + x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history))) + # concat the features back + x = torch.cat([x_vis, x_caption, x_history], dim=1) + else: + x_caption = x[:, -cap_feat_len:] + x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption))) + x = torch.cat([x_vis, x_caption], dim=1) + + assert x.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format( + x.size(1), len_init + ) + + elif expert_flag == 'fusion': + x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x))) + + return x + + +class Pooler(nn.Module): + def __init__(self, hidden_size): + super(Pooler, self).__init__() + + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + pooled_states = hidden_states[:, 0] + pooled_output = self.dense(pooled_states) + pooled_output = self.activation(pooled_output) + return pooled_output \ No newline at end of file diff --git a/models/backbones/moes_huggingface.py b/models/backbones/moes_huggingface.py new file mode 100644 index 0000000..d3e6e45 --- /dev/null +++ b/models/backbones/moes_huggingface.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from timm.models.layers import DropPath +import warnings +from torch import Tensor +from typing import Optional, Tuple + + +from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig + +class MoELayer(nn.Module): + def __init__(self, config, expert_type): + super(MoELayer, self).__init__() + self.config = config + self.expert_type = expert_type + self.bert_config = BertConfig.from_pretrained('bert-large-uncased') + + # Shared across all experts + self.attention = BertAttention(self.bert_config) + + # One for each expert + if expert_type == 'modalities': + # Spatial expert + self.intermediate_spatial = BertIntermediate(self.bert_config) + self.output_spatial = BertOutput(self.bert_config) + + # Temporal expert + self.intermediate_temporal = BertIntermediate(self.bert_config) + self.output_temporal = BertOutput(self.bert_config) + + # Vis Expert + self.intermediate_vis = BertIntermediate(self.bert_config) + self.output_vis = BertOutput(self.bert_config) + + # Caption Expert + self.intermediate_caption = BertIntermediate(self.bert_config) + self.output_caption = BertOutput(self.bert_config) + + if config.stage != 'stage_1': + # History Expert + self.intermediate_history = BertIntermediate(self.bert_config) + self.output_history = BertOutput(self.bert_config) + + # Fusion expert + elif expert_type == 'fusion': + self.intermediate_fusion = BertIntermediate(self.bert_config) + self.output_fusion = BertOutput(self.bert_config) + else: + raise ValueError + + self._init_weights() + + def _init_weights(self): + for _, m in dict(self.named_modules()).items(): + if isinstance(m, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range) + elif isinstance(m, nn.LayerNorm): + m.bias.data.zero_() + m.weight.data.fill_(1.0) + if isinstance(m, nn.Linear) and m.bias is not None: + m.bias.data.zero_() + + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + if not (attention_mask.dim() == 2 and self.bert_config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + # if self.config.is_decoder: + # extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + # input_shape, attention_mask, device + # ) + # else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + + def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False): + + input_shape = hidden_states.size()[:-1] + # dtype = mask.dtype + # device = mask.device + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + head_mask=None + ) + attention_output = self_attention_outputs[0] + # outputs = self_attention_outputs[1:] + + len_init = attention_output.size(1) + # bs, h_dim = x.size(0), x.size(-1) + # device = x.device + + + if expert_flag == 'modalities': + if only_text: + intermediate_output = self.intermediate_caption(attention_output) + layer_output = self.output_caption(intermediate_output, attention_output) + else: + # split the input first into different parts/modalities + unchanged = attention_output[:, :special_toks_indices[''], :] + end_idx_spatial = special_toks_indices.get('', special_toks_indices['']) + attention_spatial = attention_output[:, special_toks_indices['']:end_idx_spatial, :] + + end_idx_caption = special_toks_indices.get('', special_toks_indices[''] + 1) + attention_caption = attention_output[:, special_toks_indices['']: end_idx_caption, :] + + attention_temporal, attention_history = None, None + + if '' in special_toks_indices: + end_idx_temporal = special_toks_indices[''] + attention_temporal = attention_output[:, special_toks_indices['']:end_idx_temporal, :] + + if '' in special_toks_indices: + end_idx_history = special_toks_indices[''] + 1 + attention_history = attention_output[:, special_toks_indices['']:end_idx_history, :] + + # Expert activation + # 1- Spatial + intermediate_spatial = self.intermediate_spatial(attention_spatial) + output_sapatial = self.output_spatial(intermediate_spatial, attention_spatial) + + output_vis = output_sapatial + + # 2- Temporal + if attention_temporal is not None: + intermediate_temporal = self.intermediate_temporal(attention_temporal) + output_temporal = self.output_temporal(intermediate_temporal, attention_temporal) + + attention_vis = torch.concat([output_sapatial, output_temporal], dim=1) + intermediate_vis = self.intermediate_vis(attention_vis) + output_vis = self.output_vis(intermediate_vis, attention_vis) + + # 3- Caption + intermediate_caption = self.intermediate_caption(attention_caption) + output_caption = self.output_caption(intermediate_caption, attention_caption) + + # 4- History + if attention_history is not None: + intermediate_history = self.intermediate_history(attention_history) + output_history = self.output_history(intermediate_history, attention_history) + + output_list = [unchanged, output_vis, output_caption] + + if attention_history is not None: + output_list.append(output_history) + + # Concat the features back + layer_output = torch.concat(output_list, dim=1) + assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format( + layer_output.size(1), len_init + ) + + elif expert_flag == 'fusion': + intermediate_output = self.intermediate_fusion(attention_output) + layer_output = self.output_fusion(intermediate_output, attention_output) + + return layer_output + + +class MoEPooler(nn.Module): + def __init__(self): + super(MoEPooler, self).__init__() + + self.bert_config = BertConfig.from_pretrained('bert-large-uncased') + hidden_size = self.bert_config.hidden_size + + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + self._init_weights() + + def _init_weights(self): + for _, m in dict(self.named_modules()).items(): + if isinstance(m, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range) + elif isinstance(m, nn.LayerNorm): + m.bias.data.zero_() + m.weight.data.fill_(1.0) + if isinstance(m, nn.Linear) and m.bias is not None: + m.bias.data.zero_() + + def forward(self, hidden_states, idx): + pooled_states = hidden_states[:, idx] + pooled_output = self.dense(pooled_states) + pooled_output = self.activation(pooled_output) + return pooled_output diff --git a/models/backbones/moes_original.py b/models/backbones/moes_original.py new file mode 100644 index 0000000..6f7c737 --- /dev/null +++ b/models/backbones/moes_original.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from timm.models.layers import DropPath + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if mask is not None: + if mask.dim() != x.dim(): + expanded_mask = mask[:, None, None, :].expand(B, 1, N, N) + else: + expanded_mask = mask + expanded_mask = expanded_mask.bool() + attn = attn.masked_fill(~expanded_mask, float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class MoELayer(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.SiLU, + norm_layer=LlamaRMSNorm, + ): + super().__init__() + self.norm_att = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # EXPERT CONSTRUCTION + mlp_hidden_dim = int(dim * mlp_ratio) + + + # Spatial expert + self.norm_spatial = norm_layer(dim) + self.mlp_spatial = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Temporal expert + self.norm_temp = norm_layer(dim) + self.mlp_temp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Vis expert + self.norm_vis = norm_layer(dim) + self.mlp_vis = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # caption expert + self.norm_cap = norm_layer(dim) + self.mlp_cap = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # history expert + self.norm_hist = norm_layer(dim) + self.mlp_hist = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Fusion expert + self.norm_fusion = norm_layer(dim) + self.mlp_fusion = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # expert_flag:{Only Text : 00 , Only Image : 01, Fusion : 10, Text & Image : 11} (BINARY) + + # expert_flag: + # 0: + + def forward(self, x, special_toks_indices, expert_flag, mask=None): + x_shortcut, attn = self.attn(self.norm_att(x), mask=mask) + x = x + self.drop_path(x_shortcut) + bs, h_dim = x.size(0), x.size(-1) + device = x.device + + if expert_flag == 'modalities': + end_index = special_toks_indices.get('', special_toks_indices['']) + spatial_feats = x[:, special_toks_indices['']: end_index, :] + spatial_feats = spatial_feats + self.drop_path(self.mlp_spatial(self.norm_spatial(spatial_feats))) + spatial_index = torch.arange(special_toks_indices[''], end_index, device=device) + spatial_index = spatial_index.unsqueeze(0).unsqueeze(-1) + spatial_index = spatial_index.repeat(bs, 1, h_dim) + x = x.scatter(1, spatial_index, spatial_feats) + # x[:, special_toks_indices['']: special_toks_indices[''], :] = spatial_feats + + end_index = special_toks_indices.get('', special_toks_indices['']) + caption_feats = x[:, special_toks_indices['']: end_index, :] + caption_feats = caption_feats + self.drop_path(self.mlp_cap(self.norm_cap(caption_feats))) + caption_index = torch.arange(special_toks_indices[''], end_index, device=device) + caption_index = caption_index.unsqueeze(0).unsqueeze(-1) + caption_index = caption_index.repeat(bs, 1, h_dim) + x = x.scatter(1, caption_index, caption_feats) + + # x[:, special_toks_indices['']: special_toks_indices[''], :] = caption_feats + + if '' in special_toks_indices: + temporal_feats = x[:, special_toks_indices['']: special_toks_indices[''], :] + temporal_feats = temporal_feats + self.drop_path(self.mlp_temp(self.norm_temp(temporal_feats))) + temporal_index = torch.arange(special_toks_indices[''], special_toks_indices[''], device=device) + temporal_index = temporal_index.unsqueeze(0).unsqueeze(-1) + temporal_index = temporal_index.repeat(bs, 1, h_dim) + x = x.scatter(1, temporal_index, temporal_feats) + + # x[:, special_toks_indices['']: special_toks_indices[''], :] = temporal_feats + + vis_feats = x[:, special_toks_indices['']: special_toks_indices[''], :] + vis_feats = vis_feats + self.drop_path(self.mlp_vis(self.norm_vis(vis_feats))) + vis_index = torch.arange(special_toks_indices[''], special_toks_indices[''], device=device) + vis_index = vis_index.unsqueeze(0).unsqueeze(-1) + vis_index = vis_index.repeat(bs, 1, h_dim) + x = x.scatter(1, vis_index, vis_feats) + + # x[:, special_toks_indices['']: special_toks_indices[''], :] = vis_feats + + if '' in special_toks_indices: + history_feats = x[:, special_toks_indices['']: special_toks_indices[''], :] + history_feats = history_feats + self.drop_path(self.mlp_hist(self.norm_hist(history_feats))) + history_index = torch.arange(special_toks_indices[''], special_toks_indices[''], device=device) + history_index = history_index.unsqueeze(0).unsqueeze(-1) + history_index = history_index.repeat(bs, 1, h_dim) + x = x.scatter(1, history_index, history_feats) + + elif expert_flag == 'fusion': + x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x))) + + return x, attn + + # if expert_flag == 2: + # x = x + self.drop_path(self.mlp(self.norm2(x))) + # elif expert_flag == 0: + # x = (x[:, -it_split:]) + # x = x + self.drop_path(self.sentence_mlp(self.sentence_norm(x))) + # elif expert_flag == 1: + # x = (x[:, :-it_split ]) + # x = x + self.drop_path(self.image_mlp(self.image_norm(x))) + # elif expert_flag == 3: + # text, image = (x[:, :it_split], x[:, it_split:],) + # text = text + self.drop_path(self.sentence_mlp(self.sentence_norm(text))) + # image = image + self.drop_path(self.image_mlp(self.image_norm(image))) + # x = torch.cat([text, image], dim=1) + # elif expert_flag == 4: + # x = x + self.drop_path(self.generation_mlp(self.generation_norm(x))) + # return x, attn \ No newline at end of file diff --git a/models/common/__init__.py b/models/common/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/models/common/config.py b/models/common/config.py new file mode 100755 index 0000000..0d092a3 --- /dev/null +++ b/models/common/config.py @@ -0,0 +1,474 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from minigpt4.common.registry import registry + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + dataset_config = self.build_dataset_config(config) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." + + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + print("--------------") + print("model arch",model.arch) + print("model cls",model_cls) + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hierarchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." + ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + + print("dataset name", dataset_name) + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. + + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. + """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add argumetns for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Do not support for epoch-based runner because how to define an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. Support 'cuda' or 'cpu' as for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["genearte", "rank"], + help="""Inference method to use for question answering. If rank, requires a answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/models/common/dist_utils.py b/models/common/dist_utils.py new file mode 100755 index 0000000..07919b0 --- /dev/null +++ b/models/common/dist_utils.py @@ -0,0 +1,203 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if args.distributed is False: + print("Not using distributed mode") + args.rank = 0 + return + + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + args.rank = 0 + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/models/common/eval_utils.py b/models/common/eval_utils.py new file mode 100644 index 0000000..0450873 --- /dev/null +++ b/models/common/eval_utils.py @@ -0,0 +1,224 @@ +import argparse +import numpy as np +from nltk.translate.bleu_score import sentence_bleu +import sys +sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img') +from minigpt4.common.registry import registry +from minigpt4.common.config import Config + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +# from minigpt4.runners import * +from minigpt4.tasks import * +from pycocoevalcap.cider.cider import Cider +import os +import openai +from tqdm import tqdm +import json +import ast +import time + +def eval_parser(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml") + parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint") + parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.") + parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens") + parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") + parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + return parser + + +def prepare_texts(texts, conv_temp, template='', lengths=None): + convs = [conv_temp.copy() for _ in range(len(texts))] + if lengths is None: + [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)] + else: + templates = [template * length for length in lengths] + [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)] + [conv.append_message(conv.roles[1], None) for conv in convs] + texts = [conv.get_prompt() for conv in convs] + return texts + + +def init_model(args): + print('Initialization Model') + cfg = Config(args) + cfg.model_cfg.ckpt = args.ckpt + cfg.model_cfg.lora_r = args.lora_r + cfg.model_cfg.lora_alpha = args.lora_alpha + + model_config = cfg.model_cfg + model_config.low_resource = True + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to('cuda:0') + +# import pudb; pudb.set_trace() + key = list(cfg.datasets_cfg.keys())[0] + vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train + print(vis_processor_cfg) + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + print('Initialization Finished') + return model, vis_processor + +def computeIoU(bbox1, bbox2): + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou + +def eval_bleu(results): + bleus1,bleus2,bleus3,bleus4 = [],[],[],[] + for result in tqdm (results,desc="bleu_eval"): + gt = result['gt'] + pred = result['pred'] + bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1,0,0,0))) + bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5,0.5,0,0))) + bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33,0.33,0.33,0))) + bleus4.append(sentence_bleu([gt.split()], pred.split())) + # print(np.mean(bleus1),np.mean(bleus2),np.mean(bleus3),np.mean(bleus4),flush=True) + return {'bleu1':np.mean(bleus1),'bleu2':np.mean(bleus2),'bleu3':np.mean(bleus3),'bleu4':np.mean(bleus4)} + +# Create a Cider object +cider_scorer = Cider() +def eval_cider(pred_result,gt_result): + # Compute CIDEr scores + mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result) + cider_scores_dict={} + for score,pred_vid_id,gt_vid_id in tqdm(zip(cider_scores.tolist(),pred_result,gt_result),desc="cider_eval") : + assert pred_vid_id==gt_vid_id + cider_scores_dict[pred_vid_id] = score + return {'mean_cider_scores':mean_cider_scores,'cider_scores':cider_scores_dict} + + +openai.api_key_path = "/home/ataallka/chatgpt_api.txt" + + +def chat_gpt_eval(results,output_path): + trial=0 + gpt_results=[] + avg_chatgpt_score=0 + existed_files={} + # read previous results from output path + for file in os.listdir(output_path): + if file.endswith(".json"): + with open(f'{output_path}/{file}') as json_file: + data = json.load(json_file) + gpt_results.append(data[0]) + avg_chatgpt_score+=float(data[0]['chatgpt_score']) + existed_files[data[0]['video_name']]=True + length_output_path=len(os.listdir(output_path)) + while len (results)!= length_output_path: + for res in tqdm(results,desc="chatgpt_eval"): + if existed_files.get(res['video_name'],False): + continue + video_name=res['video_name'] + sentence_1=res['A'] + sentence_2=res['pred'] + try: + # prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:" + prompt=f"given these 2 sentences the first one is the ground truth descrption of a video and the second sentence is the generated text from a video summarization model,give it a score from 0 to 5 to evaluate the model summarization performance.the output should be only the score number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:" + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": prompt + }], + ) + res['chatgpt_score']=response.choices[0].message['content'] + out={'video_name':video_name,'chatgpt_score':response.choices[0].message['content']} + gpt_results.append(out) + # save each video result in a json file + with open(f'{output_path}/{video_name}.json', 'w') as f: + json.dump([out], f) + avg_chatgpt_score+=float(response.choices[0].message['content']) + except Exception as e: + print("chat gpt error",e) + print ("Finished chat gpt evaluation in trial",trial) + trial+=1 + length_output_path=len(os.listdir(output_path)) + return results,avg_chatgpt_score/len(results) +def GPT4_answer(question, answer,pred): + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + # model="gpt-3.5-turbo", + model='gpt-4', + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + return response_dict + except Exception as e: + print(f"Error : {e}") + return None +def GPT4_evaluation(val_result): + scores=[] + yes_count=0 + no_count=0 + for res in val_result: + gpt_response=GPT4_answer(res['Q'],res['A'],res['pred']) + if gpt_response is None: + continue + try: + scores.append(float(gpt_response['score'])) + if 'yes' in gpt_response['pred'].lower(): + yes_count+=1 + elif 'no' in gpt_response['pred'].lower(): + no_count+=1 + except: + continue + avg_score=sum(scores)/len(scores) + accuracy=(yes_count/(yes_count+no_count))*100 + print(f"chatgpt score: {avg_score} accuracy: {accuracy}") + return avg_score,accuracy + +# with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f: +# results = json.load(f) +# t1=time.time() +# avg_score,accuracy=GPT4_evaluation(results) +# print(f"chatgpt score: {avg_score} accuracy: {accuracy}") +# print(f"Time taken: {time.time()-t1}") \ No newline at end of file diff --git a/models/common/gradcam.py b/models/common/gradcam.py new file mode 100755 index 0000000..d53a525 --- /dev/null +++ b/models/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/models/common/logger.py b/models/common/logger.py new file mode 100755 index 0000000..9a5a727 --- /dev/null +++ b/models/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from minigpt4.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/models/common/optims.py b/models/common/optims.py new file mode 100755 index 0000000..270e66b --- /dev/null +++ b/models/common/optims.py @@ -0,0 +1,119 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from minigpt4.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + iters_per_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.iters_per_epoch = iters_per_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + total_cur_step = cur_epoch * self.iters_per_epoch + cur_step + if total_cur_step < self.warmup_steps: + warmup_lr_schedule( + step=total_cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + epoch=total_cur_step, + optimizer=self.optimizer, + max_epoch=self.max_epoch * self.iters_per_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/models/common/registry.py b/models/common/registry.py new file mode 100755 index 0000000..c953097 --- /dev/null +++ b/models/common/registry.py @@ -0,0 +1,330 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. + + Usage: + + from minigpt4.common.registry import registry + from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + + assert issubclass( + builder_cls, BaseDatasetBuilder + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(task_cls): + from minigpt4.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(model_cls): + # from minigpt4.models import BaseModel + + # assert issubclass( + # model_cls, BaseModel + # ), "All models must inherit BaseModel class" + + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(processor_cls): + from minigpt4.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from minigpt4.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. + default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. + Usage:: + + from mmf.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/models/common/utils.py b/models/common/utils.py new file mode 100755 index 0000000..f665d5b --- /dev/null +++ b/models/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from models.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M") + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no indirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading an URL from google drive requires confirmation when + the file of the size is too big (google drive notifies that + anti-viral checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: + from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. + """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. + """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. + """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py new file mode 100644 index 0000000..07ca21d --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +import sys +dataDir = '../../VQA' +sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir)) +from vqa import VQA +from vqaEvaluation.vqaEval import VQAEval +import matplotlib.pyplot as plt +import skimage.io as io +import json +import random +import os + +# set up file names and paths +versionType ='v2_' # this should be '' when using VQA v2.0 dataset +taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 +dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. +dataSubType ='train2014' +annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) +quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) +imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) +resultType ='fake' +fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType'] + +# An example result json file has been provided in './Results' folder. + +[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \ +resultType, fileType) for fileType in fileTypes] + +# create vqa object and vqaRes object +vqa = VQA(annFile, quesFile) +vqaRes = vqa.loadRes(resFile, quesFile) + +# create vqaEval object by taking vqa and vqaRes +vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2 + +# evaluate results +""" +If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function +By default it uses all the question ids in annotation file +""" +vqaEval.evaluate() + +# print accuracies +print "\n" +print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']) +print "Per Question Type Accuracy is the following:" +for quesType in vqaEval.accuracy['perQuestionType']: + print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType]) +print "\n" +print "Per Answer Type Accuracy is the following:" +for ansType in vqaEval.accuracy['perAnswerType']: + print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType]) +print "\n" +# demo how to use evalQA to retrieve low score result +evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy +if len(evals) > 0: + print 'ground truth answers' + randomEval = random.choice(evals) + randomAnn = vqa.loadQA(randomEval) + vqa.showQA(randomAnn) + + print '\n' + print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval]) + ann = vqaRes.loadQA(randomEval)[0] + print "Answer: %s\n" %(ann['answer']) + + imgId = randomAnn[0]['image_id'] + imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' + if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# plot accuracy for various question types +plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center') +plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10) +plt.title('Per Question Type Accuracy', fontsize=10) +plt.xlabel('Question Types', fontsize=10) +plt.ylabel('Accuracy', fontsize=10) +plt.show() + +# save evaluation results to ./Results folder +json.dump(vqaEval.accuracy, open(accuracyFile, 'w')) +json.dump(vqaEval.evalQA, open(evalQAFile, 'w')) +json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w')) +json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w')) + diff --git a/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py new file mode 100644 index 0000000..148424d --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py @@ -0,0 +1 @@ +author='aagrawal' diff --git a/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py new file mode 100644 index 0000000..8a65604 --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py @@ -0,0 +1,192 @@ +# coding=utf-8 + +__author__='aagrawal' + +import re +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys + + +class VQAEval: + def __init__(self, vqa, vqaRes, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + self.params = {'question_id': vqa.getQuesIds()} + self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ + "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ + "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ + "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ + "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \ + "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ + "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ + "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ + "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ + "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ + "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ + "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ + "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \ + "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ + "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ + "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ + "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ + "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ + "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ + "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ + "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ + "youll": "you'll", "youre": "you're", "youve": "you've"} + self.manualMap = { 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10' + } + self.articles = ['a', + 'an', + 'the' + ] + + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(\,)(\d)") + self.punct = [';', r"/", '[', ']', '"', '{', '}', + '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!'] + + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params['question_id']] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + # print "computing accuracy" + step = 0 + for quesId in quesIds: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = ansDic['answer'].replace('\n', ' ') + ansDic['answer'] = ansDic['answer'].replace('\t', ' ') + ansDic['answer'] = ansDic['answer'].strip() + resAns = res[quesId]['answer'] + resAns = resAns.replace('\n', ' ') + resAns = resAns.replace('\t', ' ') + resAns = resAns.strip() + gtAcc = [] + gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] + + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = self.processPunctuation(ansDic['answer']) + ansDic['answer'] = self.processDigitArticle(ansDic['answer']) + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + + for gtAnsDatum in gts[quesId]['answers']: + otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] + matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()] + acc = min(1, float(len(matchingAns))/3) + gtAcc.append(acc) + quesType = gts[quesId]['question_type'] + ansType = gts[quesId]['answer_type'] + avgGTAcc = float(sum(gtAcc))/len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step%100 == 0: + self.updateProgress(step/float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + # print "Done computing accuracy" + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = self.periodStrip.sub("", + outText, + re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = ' '.join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) + self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} + self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100*acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100*acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100*acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength*progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/models/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py b/models/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py new file mode 100644 index 0000000..406b596 --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py @@ -0,0 +1,73 @@ +# coding: utf-8 + +from vqaTools.vqa import VQA +import random +import skimage.io as io +import matplotlib.pyplot as plt +import os + +dataDir ='../../VQA' +versionType ='v2_' # this should be '' when using VQA v2.0 dataset +taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 +dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. +dataSubType ='train2014' +annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) +quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) +imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) + +# initialize VQA api for QA annotations +vqa=VQA(annFile, quesFile) + +# load and display QA annotations for given question types +""" +All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder. +""" +annIds = vqa.getQuesIds(quesTypes='how many'); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# load and display QA annotations for given answer types +""" +ansTypes can be one of the following +yes/no +number +other +""" +annIds = vqa.getQuesIds(ansTypes='yes/no'); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# load and display QA annotations for given images +""" +Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[]) +Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types. +""" +ids = vqa.getImgIds() +annIds = vqa.getQuesIds(imgIds=random.sample(ids,5)); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + diff --git a/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py b/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py new file mode 100644 index 0000000..072d8d9 --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py @@ -0,0 +1 @@ +__author__ = 'aagrawal' diff --git a/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py b/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py new file mode 100644 index 0000000..4f76961 --- /dev/null +++ b/models/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py @@ -0,0 +1,179 @@ +__author__ = 'aagrawal' +__version__ = '0.9' + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + # print 'loading VQA annotations and questions into memory...' + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, 'r')) + questions = json.load(open(question_file, 'r')) + # print datetime.datetime.utcnow() - time_t + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} + qa = {ann['question_id']: [] for ann in self.dataset['annotations']} + qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} + for ann in self.dataset['annotations']: + imgToQA[ann['image_id']] += [ann] + qa[ann['question_id']] = ann + for ques in self.questions['questions']: + qqa[ques['question_id']] = ques + # print 'index created!' + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + + # for key, value in self.datset['info'].items(): + # print '%s: %s'%(key, value) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(imgIds) == 0: + anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], []) + else: + anns = self.dataset['annotations'] + anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] + anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] + ids = [ann['question_id'] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(quesIds) == 0: + anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], []) + else: + anns = self.dataset['annotations'] + anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] + anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] + ids = [ann['image_id'] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann['question_id'] + print("Question: %s" % (self.qqa[quesId]['question'])) + for ans in ann['answers']: + print("Answer %d: %s" % (ans['answer_id'], ans['answer'])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset['info'] = copy.deepcopy(self.questions['info']) + res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) + res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) + res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) + res.dataset['license'] = copy.deepcopy(self.questions['license']) + + # print 'Loading and preparing results... ' + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, 'results is not an array of objects' + annsQuesIds = [ann['question_id'] for ann in anns] + assert set(annsQuesIds) == set(self.getQuesIds()), \ + 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' + for ann in anns: + quesId = ann['question_id'] + if res.dataset['task_type'] == 'Multiple Choice': + assert ann['answer'] in self.qqa[quesId][ + 'multiple_choices'], 'predicted answer is not one of the multiple choices' + qaAnn = self.qa[quesId] + ann['image_id'] = qaAnn['image_id'] + ann['question_type'] = qaAnn['question_type'] + ann['answer_type'] = qaAnn['answer_type'] + # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()) + + res.dataset['annotations'] = anns + res.createIndex() + return res diff --git a/models/common/vqa_tools/VQA/README.md b/models/common/vqa_tools/VQA/README.md new file mode 100644 index 0000000..439d59d --- /dev/null +++ b/models/common/vqa_tools/VQA/README.md @@ -0,0 +1,80 @@ +Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset. +=================== +## VQA v2.0 release ## +This release consists of +- Real + - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download)) + - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing + - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question) + +There is only one type of task +- Open-ended task + +## VQA v1.0 release ## +This release consists of +- Real + - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download)) + - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image) + - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question) +- Abstract + - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images + - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image) + - 600,000 answers for training and 300,000 answers for validation (10 per question) + +There are two types of tasks +- Open-ended task +- Multiple-choice task (18 choices per question) + +## Requirements ## +- python 2.7 +- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation) +- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation) + +## Files ## +./Questions +- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder. +- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). +- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below + - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip) + - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip) +- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip). + +./Annotations +- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder. +- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). +- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below + - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip) + - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip) +- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip). + +./Images +- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders. +- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders. + +./PythonHelperTools +- This directory contains the Python API to read and visualize the VQA dataset +- vqaDemo.py (demo script) +- vqaTools (API to read and visualize data) + +./PythonEvaluationTools +- This directory contains the Python evaluation code +- vqaEvalDemo.py (evaluation demo script) +- vqaEvaluation (evaluation code) + +./Results +- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo) +- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details. + +./QuestionTypes +- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k. +- mscoco_question_types.txt +- abstract_v002_question_types.txt + +## References ## +- [VQA: Visual Question Answering](http://visualqa.org/) +- [Microsoft COCO](http://mscoco.org/) + +## Developers ## +- Aishwarya Agrawal (Virginia Tech) +- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco). +- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption). diff --git a/models/common/vqa_tools/__init__.py b/models/common/vqa_tools/__init__.py new file mode 100644 index 0000000..9b98da8 --- /dev/null +++ b/models/common/vqa_tools/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" diff --git a/models/common/vqa_tools/aokvqa/LICENSE b/models/common/vqa_tools/aokvqa/LICENSE new file mode 100644 index 0000000..663d675 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 Allen Institute for Artificial Intelligence + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/models/common/vqa_tools/aokvqa/README.md b/models/common/vqa_tools/aokvqa/README.md new file mode 100644 index 0000000..21caefa --- /dev/null +++ b/models/common/vqa_tools/aokvqa/README.md @@ -0,0 +1,207 @@ +# A-OKVQA + +Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**. + +Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public) + +### Abstract + +The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models. + +![dataset_web](https://user-images.githubusercontent.com/28768645/170799740-f0d9ea60-6aff-4322-98d5-cae8e05983f4.svg) + +
+ +#### Table of Contents + +- [Getting started](#getting-started) + * [Downloading the dataset](#downloading-the-dataset) +- [Evaluation & Leaderboard](#evaluation) +- [Codebase](#codebase) + * [Preparing data](#preparing-data) + * [Models and Predictions](#models-and-predictions) + +
+ +## Getting started + +```bash +git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git + +cd aokvqa +export PYTHONPATH=. + +conda env create --name aokvqa +conda activate aokvqa +``` + +### Downloading the dataset + +```bash +export AOKVQA_DIR=./datasets/aokvqa/ +mkdir -p ${AOKVQA_DIR} + +curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR} +``` + +
Downloading COCO 2017 + +```bash +export COCO_DIR=./datasets/coco/ +mkdir -p ${COCO_DIR} + +for split in train val test; do + wget "http://images.cocodataset.org/zips/${split}2017.zip" + unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip" +done + +wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip +unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip +``` + +
+ +Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code. + +```python +import os +aokvqa_dir = os.getenv('AOKVQA_DIR') + +from load_aokvqa import load_aokvqa, get_coco_path +train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test' +``` + +
Example dataset entry + +```python +dataset_example = train_dataset[0] + +print(dataset_example['question_id']) +# 22MexNkBPpdZGX6sxbxVBH + +coco_dir = os.getenv('COCO_DIR') +image_path = get_coco_path('train', dataset_example['image_id'], coco_dir) +print(image_path) +# ./datasets/coco/train2017/000000299207.jpg + +print(dataset_example['question']) +print(dataset_example['choices']) +# What is the man by the bags awaiting? +# ['skateboarder', 'train', 'delivery', 'cab'] + +correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ] +# Corrrect: cab + +print(dataset_example['rationales'][0]) +# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer. +``` + +
+ +## Evaluation + +Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either `multiple_choice` or `direct_answer` field if you only want to evaluate one setting. + +```python +{ + '' : { + 'multiple_choice' : '', + 'direct_answer' : '' + } +} +``` + +You can run evaluation on the validation set as follows. + +```bash +python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json +``` + +### Leaderboard + +You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started). + +## Codebase + +We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3). + +### Preparing data + +```bash +export FEATURES_DIR=./features/ +mkdir -p ${FEATURES_DIR} +``` + +You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments. + +```bash +python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt + +for split in train val test; do + python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt +done +``` + +
For training ClipCap with a transformer mapping network + +If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`. + +
+ +
For ResNet and BERT input features + +Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands: + +```bash +# ResNet +for split in train val test; do + python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt +done + +# BERT +for split in train val test; do + python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt +done +``` + +
+ +### Models and Predictions + +```bash +export LOG_DIR=./logs/ +export PREDS_DIR=./predictions/ +export PT_MODEL_DIR=./pretrained_models/ +mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR} +``` + +
Download our pretrained model weights + +```bash +# Checkpoints for transfer learning experiments +curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models + +# Checkpoints for ClipCap models (generating answers and rationales) +curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models +``` + +
+ +We have included instructions for replicating each of our experiments (see README.md files below). + +All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above. + +- [Heuristics](./heuristics/README.md) +- [Transfer Learning Experiments](./transfer_experiments/README.md) +- [Querying GPT-3](./gpt3/README.md) +- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md) +- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md) + +For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set. + +We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.) + +```bash +python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json +# repeat for test split ... +``` diff --git a/models/common/vqa_tools/aokvqa/data_scripts/build_vocab.py b/models/common/vqa_tools/aokvqa/data_scripts/build_vocab.py new file mode 100644 index 0000000..2c44686 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/data_scripts/build_vocab.py @@ -0,0 +1,45 @@ +import os +import argparse +from collections import Counter +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + + +# Build vocab from train set: correct choices + (direct answers appearing in >= 3 ) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') + +vocab = [] +all_choices = Counter() +direct_answers = Counter() + +for i in train_set: + vocab.append( i['choices'][i['correct_choice_idx']] ) + all_choices.update(i['choices']) + direct_answers.update(set(i['direct_answers'])) +vocab += [k for k,v in all_choices.items() if v >= 3] +vocab += [k for k,v in direct_answers.items() if v >= 3] + +vocab = sorted(set(vocab)) +print(f"Vocab size: {len(vocab)}") + +# Save vocabulary Output + +with open(args.output_file, 'w') as f: + for v in vocab: + print(v, file=f) + +## Check validation set coverage + +val_set = load_aokvqa(args.aokvqa_dir, 'val') + +val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set] +val_acc = sum(val_acc) / len(val_acc) * 100 +print(f"Val set coverage: {val_acc:.2f}" ) diff --git a/models/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py b/models/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py new file mode 100644 index 0000000..1dce760 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py @@ -0,0 +1,26 @@ +import json +from tqdm import tqdm +import argparse +import pathlib + +import torch +import clip + +parser = argparse.ArgumentParser() +parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file') +parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +device = "cuda" if torch.cuda.is_available() else "cpu" +model, preprocess = clip.load(args.model_type, device=device) + +with torch.no_grad(): + a = open(args.vocab_file).read().splitlines() + mc_text = clip.tokenize(a).to(device) + mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0] + mc_text_features = mc_text_features.float() + model_name = args.model_type.replace('/', '-').replace('@', '-') + torch.save(mc_text_features, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py b/models/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py new file mode 100644 index 0000000..60cd40f --- /dev/null +++ b/models/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py @@ -0,0 +1,50 @@ +import os +import argparse +import pathlib +from tqdm import tqdm + +import torch +from transformers import AutoTokenizer, AutoModel + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens') +model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens') +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) +model.eval() + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] # First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt') + encoded_input = {k:v.to(device) for k,v in encoded_input.items()} + e = mean_pooling(model(**encoded_input), encoded_input['attention_mask']) + embeddings[d['question_id']] = { + 'question' : e[0].cpu() + } + + torch.save(embeddings, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py b/models/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py new file mode 100644 index 0000000..20d0455 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py @@ -0,0 +1,51 @@ +import os +from PIL import Image +from tqdm import tqdm +import argparse +import pathlib + +import torch +import clip + +from load_aokvqa import load_aokvqa, get_coco_path + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +device = "cuda" if torch.cuda.is_available() else "cpu" +model, preprocess = clip.load(args.model_type, device=device) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + q = d["question"] + q_text = clip.tokenize(q).to(device) + q_text_features = model.encode_text(q_text) + + img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)) + img = preprocess(img).unsqueeze(0).to(device) + image_features = model.encode_image(img) + + embeddings[d['question_id']] = { + 'question' : q_text_features[0].float().cpu(), + 'image' : image_features[0].float().cpu(), + } + + torch.save(embeddings, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py b/models/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py new file mode 100644 index 0000000..0d7277b --- /dev/null +++ b/models/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py @@ -0,0 +1,62 @@ +import os +import argparse +import pathlib +from tqdm import tqdm +from PIL import Image + +import torch +import torch.nn as nn +from torchvision import models +from torchvision import transforms as T + +from load_aokvqa import load_aokvqa, get_coco_path + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +resnet_preprocess = T.Compose([ + T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC), + T.CenterCrop(size=(224, 224)), + T.ToTensor(), + T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) +]) + +device = "cuda" if torch.cuda.is_available() else "cpu" + +resnet_model = models.resnet50(pretrained=True) +resnet_model = torch.nn.Sequential( + *list(resnet_model.children())[:-1], + nn.Flatten() +) # strip classification layer +resnet_model = resnet_model.to(device) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB') + resnet_input = resnet_preprocess(img).unsqueeze(0).to(device) + resnet_features = resnet_model(resnet_input) + embeddings[d['question_id']] = { + 'image' : resnet_features[0].cpu() + } + + torch.save(embeddings, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/environment.yml b/models/common/vqa_tools/aokvqa/environment.yml new file mode 100644 index 0000000..58284ec --- /dev/null +++ b/models/common/vqa_tools/aokvqa/environment.yml @@ -0,0 +1,36 @@ +name: aokvqa +channels: + - pytorch + - nvidia + - huggingface + - conda-forge + - defaults +dependencies: + - python=3.7 + - cudatoolkit=11.3 + - numpy=1.21.6 + - pytorch=1.11.0 + - torchvision=0.12.0 + - pytorch-lightning=1.6.3 + - torchmetrics=0.8.1 + - gdown=4.4.0 + - pip=22.0.4 + - pip: + - argparse==1.4.0 + - Pillow==9.0.1 + - tensorboard==2.9.0 + - ftfy==6.1.1 + - regex==2022.3.15 + - tqdm==4.64.0 + - clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620 + - openai==0.18.1 + - nltk==3.7 + - sacrebleu==2.0.0 + - sacremoses==0.0.53 + - sentence-transformers==2.2.0 + - datasets==2.1.0 + - tokenizers==0.10.3 + - transformers==4.10.3 + +# Next: resolve conflict between sentence-transfomers and pytorch-lightning +# pip uninstall sentencepiece diff --git a/models/common/vqa_tools/aokvqa/evaluation/eval_predictions.py b/models/common/vqa_tools/aokvqa/evaluation/eval_predictions.py new file mode 100644 index 0000000..a7b5dbe --- /dev/null +++ b/models/common/vqa_tools/aokvqa/evaluation/eval_predictions.py @@ -0,0 +1,97 @@ +import argparse +import pathlib +import json +import glob + +from load_aokvqa import load_aokvqa + + +def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True): + + if isinstance(dataset, list): + dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) } + + if multiple_choice is False: + dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False} + + if strict: + dataset_qids = set(dataset.keys()) + preds_qids = set(preds.keys()) + assert dataset_qids.issubset(preds_qids) + + # dataset = q_id (str) : dataset element (dict) + # preds = q_id (str) : prediction (str) + + acc = [] + + for q in dataset.keys(): + if q not in preds.keys(): + acc.append(0.0) + continue + + pred = preds[q] + choices = dataset[q]['choices'] + direct_answers = dataset[q]['direct_answers'] + + ## Multiple Choice setting + if multiple_choice: + if strict: + assert pred in choices, 'Prediction must be a valid choice' + correct_choice_idx = dataset[q]['correct_choice_idx'] + acc.append( float(pred == choices[correct_choice_idx]) ) + ## Direct Answer setting + else: + num_match = sum([pred.lower() == da.lower() for da in direct_answers]) + vqa_acc = min(1.0, num_match / 3.0) + acc.append(vqa_acc) + + acc = sum(acc) / len(acc) * 100 + + return acc + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--preds', type=str, required=True, dest='prediction_files') + args = parser.parse_args() + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + + for prediction_file in glob.glob(args.prediction_files): + predictions = json.load(open(prediction_file, 'r')) + + # Multiple choice + + mc_predictions = {} + + for q in predictions.keys(): + if 'multiple_choice' in predictions[q].keys(): + mc_predictions[q] = predictions[q]['multiple_choice'] + + if mc_predictions != {}: + mc_acc = eval_aokvqa( + dataset, + mc_predictions, + multiple_choice=True, + strict=False + ) + print(prediction_file, 'MC', mc_acc) + + # Direct Answer + + da_predictions = {} + + for q in predictions.keys(): + if 'direct_answer' in predictions[q].keys(): + da_predictions[q] = predictions[q]['direct_answer'] + + if da_predictions != {}: + da_acc = eval_aokvqa( + dataset, + da_predictions, + multiple_choice=False, + strict=False + ) + print(prediction_file, 'DA', da_acc) diff --git a/models/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py b/models/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py new file mode 100644 index 0000000..3e3dd49 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py @@ -0,0 +1,13 @@ +import os +import json + + +def load_aokvqa(aokvqa_dir, split, version='v1p0'): + assert split in ['train', 'val', 'test', 'test_w_ans'] + dataset = json.load(open( + os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json") + )) + return dataset + +def get_coco_path(split, image_id, coco_dir): + return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg") diff --git a/models/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py b/models/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py new file mode 100644 index 0000000..202f00c --- /dev/null +++ b/models/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py @@ -0,0 +1,31 @@ +import argparse +import pathlib +import json + +from load_aokvqa import load_aokvqa + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file') + parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file') + parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file') + args = parser.parse_args() + assert args.mc_pred_file or args.da_pred_file + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None + da_preds = json.load(args.da_pred_file) if args.da_pred_file else None + predictions = {} + + for d in dataset: + q = d['question_id'] + predictions[q] = {} + if mc_preds and q in mc_preds.keys(): + predictions[q]['multiple_choice'] = mc_preds[q] + if da_preds and q in da_preds.keys(): + predictions[q]['direct_answer'] = da_preds[q] + + json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/evaluation/remap_predictions.py b/models/common/vqa_tools/aokvqa/evaluation/remap_predictions.py new file mode 100644 index 0000000..40ba155 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/evaluation/remap_predictions.py @@ -0,0 +1,44 @@ +import argparse +import pathlib +import json +from tqdm import tqdm + +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim + +from load_aokvqa import load_aokvqa + + +def map_to_choices(dataset, predictions, device='cpu'): + if isinstance(dataset, list): + dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) } + + if all([p in dataset[q]['choices'] for q, p in predictions.items()]): + return predictions + + model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d') + model.to(device) + for q in tqdm(predictions.keys()): + choices = dataset[q]['choices'] + if predictions[q] not in choices: + choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True) + a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item() + predictions[q] = choices[a_idx] + + return predictions + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file') + parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') + args = parser.parse_args() + + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + predictions = json.load(args.prediction_file) + predictions = map_to_choices(dataset, predictions) + + json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/gpt3/README.md b/models/common/vqa_tools/aokvqa/gpt3/README.md new file mode 100644 index 0000000..fc1fd6b --- /dev/null +++ b/models/common/vqa_tools/aokvqa/gpt3/README.md @@ -0,0 +1,14 @@ +## Querying GPT-3 + +To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables. + +```bash +export OPENAI_ORG=.... +export OPENAI_API_KEY=... +``` + +For producing predictions for both DA and MC settings, run: +```bash +python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json +python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json +``` diff --git a/models/common/vqa_tools/aokvqa/gpt3/caption_inputs.py b/models/common/vqa_tools/aokvqa/gpt3/caption_inputs.py new file mode 100644 index 0000000..2117434 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/gpt3/caption_inputs.py @@ -0,0 +1,23 @@ +import os +import json +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val'], required=True) +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) + +coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations'] +coco_captions = {c['image_id'] : c['caption'] for c in coco_captions} + +captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set } + +json.dump(captions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/gpt3/query_gpt3.py b/models/common/vqa_tools/aokvqa/gpt3/query_gpt3.py new file mode 100644 index 0000000..4a08900 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/gpt3/query_gpt3.py @@ -0,0 +1,79 @@ +import os +import random +import json +from tqdm import tqdm +import argparse +import pathlib + +import openai +openai.organization = os.getenv('OPENAI_ORG') +openai.api_key = os.getenv('OPENAI_API_KEY') + +from load_aokvqa import load_aokvqa + + +random.seed(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--n', type=int, default=10, dest='num_examples') + parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file') + parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix') + parser.add_argument('--include-choices', action='store_true', dest='include_choices') + parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file') + parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') + args = parser.parse_args() + + + train_set = load_aokvqa(args.aokvqa_dir, 'train') + eval_set = load_aokvqa(args.aokvqa_dir, args.split) + + train_context = {} + context = {} + if args.context_file is not None: + train_context = json.load(args.train_context_file) + context = json.load(args.context_file) + + predictions = {} + + for d in tqdm(eval_set): + q = d['question_id'] + + prompt = args.prompt_prefix + for e in random.sample(train_set, args.num_examples): + prompt += prompt_element(e, + context=train_context.get(q, None), + include_choices=args.include_choices, + answer=True + ) + prompt += '\n\n' + + prompt += prompt_element(d, + context=context.get(q, None), + include_choices=args.include_choices, + answer=False + ) + + response = openai.Completion.create( + engine="text-curie-001", + prompt=prompt, + temperature=0.0, + max_tokens=10, + ) + + predictions[q] = response.choices[0].text.strip() + + json.dump(predictions, args.output_file) + + +def prompt_element(d, context=None, include_choices=False, answer=False): + return (f"Context: {context}\n" if context is not None else '') + \ + f"Q: {d['question']}\n" + \ + (f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \ + f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '') + +if __name__ == '__main__': + main() diff --git a/models/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py b/models/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py new file mode 100644 index 0000000..411d1ee --- /dev/null +++ b/models/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py @@ -0,0 +1,16 @@ +import json +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True) +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) +rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set} +json.dump(rationales, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/heuristics/README.md b/models/common/vqa_tools/aokvqa/heuristics/README.md new file mode 100644 index 0000000..67c8632 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/heuristics/README.md @@ -0,0 +1,11 @@ +## Heuristics + +```bash +# These scripts accept the same arguments. +# heuristics/random_unweighted.py +# heuristics/random_weighted.py +# heuristics/most_common_answer.py + +python heuristics/random_unweighted.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc --out ${PREDS_DIR}/random-unweighted_val-mc.json +# Exclude --mc for the direct answer setting +``` diff --git a/models/common/vqa_tools/aokvqa/heuristics/most_common_answer.py b/models/common/vqa_tools/aokvqa/heuristics/most_common_answer.py new file mode 100644 index 0000000..59a27bc --- /dev/null +++ b/models/common/vqa_tools/aokvqa/heuristics/most_common_answer.py @@ -0,0 +1,39 @@ +import os +import json +import argparse +import pathlib +from collections import Counter + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + + +train_set = load_aokvqa(args.aokvqa_dir, 'train') +train_freq = dict(Counter( + [d['choices'][d['correct_choice_idx']] for d in train_set] +)) +most_common_answer = max(train_freq.keys(), key=train_freq.get) + +## + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +predictions = {} + +for d in eval_set: + q = d['question_id'] + predictions[q] = most_common_answer + + if args.multiple_choice: + choices = [c for c in d['choices'] if c in train_freq.keys()] + if len(choices) > 0: + predictions[q] = max(choices, key=train_freq.get) + +json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/heuristics/random_unweighted.py b/models/common/vqa_tools/aokvqa/heuristics/random_unweighted.py new file mode 100644 index 0000000..cfcf900 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/heuristics/random_unweighted.py @@ -0,0 +1,38 @@ +import os +import json +from random import seed, sample +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +seed(0) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') + +if args.multiple_choice is False: + choices = list(set( + [d['choices'][d['correct_choice_idx']] for d in train_set] + )) + +## + +predictions = {} + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +for d in eval_set: + q = d['question_id'] + if args.multiple_choice: + choices = d['choices'] + predictions[q] = sample(choices, 1)[0] + +json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/heuristics/random_weighted.py b/models/common/vqa_tools/aokvqa/heuristics/random_weighted.py new file mode 100644 index 0000000..2ccfa61 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/heuristics/random_weighted.py @@ -0,0 +1,46 @@ +import os +import json +import numpy as np +import argparse +import pathlib +from collections import Counter + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +np.random.seed(0) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') +train_freq = dict(Counter( + [d['choices'][d['correct_choice_idx']] for d in train_set] +)) + +if args.multiple_choice is False: + choices = list(train_freq.keys()) + probs = [f / len(train_set) for f in train_freq.values()] + +## + +predictions = {} + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +for d in eval_set: + if args.multiple_choice: + choices = d['choices'] + probs = [train_freq.get(c, 0) for c in choices] + if probs == [0, 0, 0, 0]: + probs = [1, 1, 1, 1] + probs = [p / sum(probs) for p in probs] + + q = d['question_id'] + predictions[q] = np.random.choice(choices, size=1, p=probs)[0] + +json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/load_aokvqa.py b/models/common/vqa_tools/aokvqa/load_aokvqa.py new file mode 100644 index 0000000..3e3dd49 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/load_aokvqa.py @@ -0,0 +1,13 @@ +import os +import json + + +def load_aokvqa(aokvqa_dir, split, version='v1p0'): + assert split in ['train', 'val', 'test', 'test_w_ans'] + dataset = json.load(open( + os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json") + )) + return dataset + +def get_coco_path(split, image_id, coco_dir): + return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg") diff --git a/models/common/vqa_tools/aokvqa/transfer_experiments/README.md b/models/common/vqa_tools/aokvqa/transfer_experiments/README.md new file mode 100644 index 0000000..dc5138d --- /dev/null +++ b/models/common/vqa_tools/aokvqa/transfer_experiments/README.md @@ -0,0 +1,41 @@ +## Transfer Learning Experiments + +We use the following training/prediction scripts for the classifier, zero-shot, and contrastive experiments in Table 3. + +```bash +## Training +python transfer_experiments/train.py --aokvqa-dir ${AOKVQA_DIR} --vocab ${AOKVQA_DIR}/large_vocab_train.csv --log-dir ${LOG_DIR} + +--backbone clip --clip-model-type ViT-B/32 --train-features ${FEATURES_DIR}/clip-ViT-B-32_train.pt --val-features ${FEATURES_DIR}/clip-ViT-B-32_val.pt +--inputs question # OR --inputs image # OR --inputs question image +# OR +--backbone resnet --train-features ${FEATURES_DIR}/resnet_train.pt --val-features ${FEATURES_DIR}/resnet_val.pt --inputs image +# OR +--backbone bert --train-features ${FEATURES_DIR}/bert_train.pt --val-features ${FEATURES_DIR}/bert_val.pt --inputs question + +--objective classifier +# OR +--objective contrastive --vocab-features ${FEATURE_DIR}/clip-ViT-B-32_large_vocab.pt +``` + +You can make predictions for CLIP zero-shot or from a classifier/contrastive checkpoint trained above. + +```bash +## Predicting +python transfer_experiments/predict.py --aokvqa-dir ${AOKVQA_DIR} --out ${PREDS_DIR}/clip-classifier_val-mc.json + +--split val # or test +--features ${FEATURE_DIR}/clip-ViT-B-32_val.pt # adjust for backbone and eval split + +--ckpt path/to/model.ckpt +# OR +--zero-shot --clip-model-type ViT-B/32 +--inputs question # OR --inputs image # OR --inputs question image + +--mc # Multiple-choice. Exclude for direct-answer. + +# IF classifier OR direct-answer +--vocab ${AOKVQA_DIR}/large_vocab_train.csv +# IF contrastive/zero-shot AND direct-answer +--vocab-features ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt +``` diff --git a/models/common/vqa_tools/aokvqa/transfer_experiments/predict.py b/models/common/vqa_tools/aokvqa/transfer_experiments/predict.py new file mode 100644 index 0000000..d2fbb42 --- /dev/null +++ b/models/common/vqa_tools/aokvqa/transfer_experiments/predict.py @@ -0,0 +1,126 @@ +import sys +import os +import argparse +import pathlib +from tqdm import tqdm +import json + +import torch +import torch.nn as nn + +# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663 +import sentencepiece; import pytorch_lightning as pl; import clip + +from transfer_experiments.train import LinearClassifier +from load_aokvqa import load_aokvqa +from evaluation.remap_predictions import map_to_choices + + +parser = argparse.ArgumentParser() +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--features', type=pathlib.Path, required=True) +parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file') +# +parser_weights = parser.add_mutually_exclusive_group(required=True) + +parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path') + +parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot') +parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv)) +# +parser.add_argument('--vocab', type=argparse.FileType('r')) +parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features') +parser.add_argument('--mc', action='store_true', dest='multiple_choice') + +parser.add_argument('--clip-model-type', type=str, + choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], + dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv)) +# +args = parser.parse_args() + + +## Load dataset + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) + +## Load models + +device = "cuda" if torch.cuda.is_available() else "cpu" + +if args.checkpoint_path is not None: + classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path) + classifier.to(device) + hp = classifier.hparams +elif args.clip_zero_shot: + classifier = nn.Identity().to(device) + hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs) + +# Load input features + +embeddings = torch.load(args.features) +if hp.backbone == 'clip': + for q in embeddings.keys(): + embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True) + embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True) + +# Load vocab, vocab features, clip + +if (hp.objective == 'classifier') or \ + (hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False): + vocab = args.vocab.read().splitlines() + +if hp.objective in ['contrastive', 'zero-shot']: + if args.multiple_choice is False: + vocab_features = torch.load(args.vocab_features).cpu() + vocab_features /= vocab_features.norm(dim=-1, keepdim=True) + else: + clip_model = clip.load(hp.clip_model_type, device=device)[0] + logit_scale = clip_model.logit_scale.exp().cpu() + +## Prediction loop + +predictions = {} + +with torch.no_grad(): + for o in tqdm(aokvqa_set): + q = o['question_id'] + + # Load input embedding (from question / image) + if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs): + e = embeddings[q]['question'] + embeddings[q]['image'] + elif 'question' in hp.inputs and 'image' in hp.inputs: + e = torch.cat((embeddings[q]['question'], embeddings[q]['image'])) + elif 'question' in hp.inputs: + e = embeddings[q]['question'] + elif 'image' in hp.inputs: + e = embeddings[q]['image'] + + # Pass inputs through model + e = e.unsqueeze(0).to(device) + x = classifier(e)[0].cpu() + + # Predict + if hp.objective in ['contrastive', 'zero-shot']: + if args.multiple_choice: + vocab = o['choices'] + # Encode choices + vocab_features = clip.tokenize(vocab).to(device) + vocab_features = torch.stack([ + clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features + ], dim=1)[0] + vocab_features /= vocab_features.norm(dim=-1, keepdim=True) + vocab_features = vocab_features.float().cpu() + + x = logit_scale * x @ vocab_features.t() + x = x.softmax(dim=-1) + + predictions[q] = vocab[x.argmax().item()] + +## Save and evaluate predictions + +# Map prediction to nearest neighbor choice (by word embeddings) +if args.multiple_choice and hp.objective == 'classifier': + predictions = map_to_choices(aokvqa_set, predictions) + +json.dump(predictions, args.output_file) diff --git a/models/common/vqa_tools/aokvqa/transfer_experiments/train.py b/models/common/vqa_tools/aokvqa/transfer_experiments/train.py new file mode 100644 index 0000000..ac48b5a --- /dev/null +++ b/models/common/vqa_tools/aokvqa/transfer_experiments/train.py @@ -0,0 +1,263 @@ +import os +import sys +import json +import argparse +import pathlib +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader + +# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663 +import sentencepiece; import pytorch_lightning as pl + +import torchmetrics.functional as MF + +from load_aokvqa import load_aokvqa + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--vocab', type=argparse.FileType('r'), required=True) + parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True) + # + parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True) + parser.add_argument('--clip-model-type', type=str, + choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], + dest='clip_model_type', required=('clip' in sys.argv)) + parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features') + parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features') + parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features') + # + parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True) + parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True) + # Defaults + parser.add_argument('--bs', type=int, default=128, dest='batch_size') + parser.add_argument('--lr', type=float, default=0.01) + parser.add_argument('--epochs', type=int, default=500) + parser.add_argument('--gpus', type=int, default=1) + args = parser.parse_args() + + pl.seed_everything(1) + vocab = args.vocab.read().splitlines() + + ## Data loading + + dm = AokvqaEmbeddingsDataModule( + args.aokvqa_dir, + args.train_features, + args.val_features, + args.objective, + args.backbone, + args.inputs, + vocab, + args.vocab_features, + batch_size=args.batch_size, + num_workers=16 + ) + + ## Model definition + + model = LinearClassifier( + args.objective, + args.backbone, + args.clip_model_type, + args.inputs, + len(vocab), + args.lr + ) + + ## Training and testing loops + + logger = pl.loggers.TensorBoardLogger( + args.log_dir, + name=f'{args.backbone}-{args.objective}', + version=f"inputs:{'+'.join(args.inputs)}" + ) + + trainer = pl.Trainer( + logger=logger, + gpus=args.gpus, + max_epochs=args.epochs, + callbacks=[ + pl.callbacks.ModelCheckpoint( + monitor="val_acc", + filename="{epoch:02d}-{val_acc:.2f}", + mode="max" + ) + ], + ) + + trainer.fit(model, dm) + + +class AokvqaEmbeddingsDataset(Dataset): + def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features): + + aokvqa_set = load_aokvqa(aokvqa_dir, split) + + assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \ + or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \ + or ( backbone == 'clip' ) + + embeddings = torch.load(input_features) + if backbone == 'clip': + for q in embeddings.keys(): + embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True) + embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True) + if objective == 'contrastive': + vocab_embeddings = torch.load(vocab_features) + vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True) + + self.objective = objective + self.vocab_len = len(vocab) + + self.embeddings = [] + self.answers = [] + + for o in aokvqa_set: + correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers']) + correct_answers = [vocab.index(a) for a in correct_answers if a in vocab] + if self.objective == 'contrastive': + correct_answers = [vocab_embeddings[a] for a in correct_answers] + if len(correct_answers) == 0: continue + self.answers.append(correct_answers) + + q = o['question_id'] + if 'question' in inputs and 'image' in inputs: + e = torch.cat((embeddings[q]['question'], embeddings[q]['image'])) + elif 'question' in inputs and 'image' not in inputs: + e = embeddings[q]['question'] + elif 'question' not in inputs and 'image' in inputs: + e = embeddings[q]['image'] + self.embeddings.append(e) + + def __getitem__(self, index): + e = self.embeddings[index] + a = self.answers[index] + if self.objective == 'classifier': + a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0) + elif self.objective == 'contrastive': + a = random.sample(a, 1)[0] + return e, a + + def __len__(self): + return len(self.embeddings) + + +class AokvqaEmbeddingsDataModule(pl.LightningDataModule): + + def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0): + super().__init__() + self.aokvqa_dir = aokvqa_dir + self.train_features = train_features + self.val_features = val_features + self.objective = objective + self.backbone = backbone + self.inputs = inputs + self.vocab = vocab + self.vocab_features = vocab_features + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage=None): + self.train_dataset = AokvqaEmbeddingsDataset( + self.aokvqa_dir, 'train', self.train_features, self.objective, + self.backbone, self.inputs, self.vocab, self.vocab_features + ) + self.val_dataset = AokvqaEmbeddingsDataset( + self.aokvqa_dir, 'val', self.val_features, self.objective, + self.backbone, self.inputs, self.vocab, self.vocab_features + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=int(0.8 * self.num_workers) + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, shuffle=False, + num_workers=int(0.2 * self.num_workers) + ) + + +class LinearClassifier(pl.LightningModule): + def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001): + super().__init__() + self.save_hyperparameters(ignore=['lr']) + self.lr = lr + + if self.hparams.backbone == 'clip': + clip_dim = { + 'RN50' : 1024, + 'RN50x4' : 640, + 'RN50x16' : 768, + 'RN50x64' : 1024, + 'RN101' : 512, + 'ViT-B/32' : 512, + 'ViT-B/16' : 512, + 'ViT-L/14' : 768, + 'ViT-L/14@336px' : 768, + }[clip_model_type] + emb_dim = clip_dim * len(inputs) + elif self.hparams.backbone == 'resnet': + emb_dim = 2048 + elif self.hparams.backbone == 'bert': + emb_dim = 768 + + if self.hparams.objective == 'classifier': + out_dim = vocab_len + elif self.hparams.objective == 'contrastive': + out_dim = clip_dim + + self.linear = nn.Linear(emb_dim, out_dim) + + def forward(self, x): + x = self.linear(x) + if self.hparams.objective == 'classifier': + x = torch.sigmoid(x) + return x + + def compute_loss(self, batch): + x, y = batch + + y_pred = self.forward(x) + + if self.hparams.objective == 'classifier': + loss = F.binary_cross_entropy(y_pred, y.float()) + elif self.hparams.objective == 'contrastive': + indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device) + sim = (y_pred @ y.T).softmax(dim=-1) + loss = F.cross_entropy(sim, indices) + + if self.hparams.objective == 'classifier': + acc = MF.f1_score(y_pred, y) + elif self.hparams.objective == 'contrastive': + acc = torch.mean(sim[indices, indices]) + + return loss, acc + + def training_step(self, batch, batch_idx): + loss, acc = self.compute_loss(batch) + self.log("train_loss", loss) + self.log("train_acc", acc) + return loss + + def validation_step(self, batch, batch_idx): + loss, acc = self.compute_loss(batch) + self.log("val_loss", loss) + self.log("val_acc", acc) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + +if __name__ == '__main__': + main() diff --git a/models/common/vqa_tools/vqa.py b/models/common/vqa_tools/vqa.py new file mode 100644 index 0000000..a386b90 --- /dev/null +++ b/models/common/vqa_tools/vqa.py @@ -0,0 +1,211 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" +__version__ = "0.9" + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print("loading VQA annotations and questions into memory...") + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, "r")) + questions = json.load(open(question_file, "r")) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print("creating index...") + imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} + qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + for ann in self.dataset["annotations"]: + imgToQA[ann["image_id"]] += [ann] + qa[ann["question_id"]] = ann + for ques in self.questions["questions"]: + qqa[ques["question_id"]] = ques + print("index created!") + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + for key, value in self.datset["info"].items(): + print("%s: %s" % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(imgIds) == 0: + anns = sum( + [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], + [], + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["question_id"] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(quesIds) == 0: + anns = sum( + [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["image_id"] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann["question_id"] + print("Question: %s" % (self.qqa[quesId]["question"])) + for ans in ann["answers"]: + print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset["info"] = copy.deepcopy(self.questions["info"]) + res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) + res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) + res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) + res.dataset["license"] = copy.deepcopy(self.questions["license"]) + + print("Loading and preparing results... ") + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, "results is not an array of objects" + annsQuesIds = [ann["question_id"] for ann in anns] + assert set(annsQuesIds) == set( + self.getQuesIds() + ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." + for ann in anns: + quesId = ann["question_id"] + if res.dataset["task_type"] == "Multiple Choice": + assert ( + ann["answer"] in self.qqa[quesId]["multiple_choices"] + ), "predicted answer is not one of the multiple choices" + qaAnn = self.qa[quesId] + ann["image_id"] = qaAnn["image_id"] + ann["question_type"] = qaAnn["question_type"] + ann["answer_type"] = qaAnn["answer_type"] + print( + "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) + ) + + res.dataset["annotations"] = anns + res.createIndex() + return res diff --git a/models/common/vqa_tools/vqa_eval.py b/models/common/vqa_tools/vqa_eval.py new file mode 100644 index 0000000..ee808b3 --- /dev/null +++ b/models/common/vqa_tools/vqa_eval.py @@ -0,0 +1,324 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = "aagrawal" + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys +import re + + +class VQAEval: + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + if vqa is not None: + self.params = {"question_id": vqa.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print("computing accuracy") + step = 0 + for quesId in quesIds: + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]["question_type"] + ansType = gts[quesId]["answer_type"] + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print("Done computing accuracy") + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n) + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/models/criteria.py b/models/criteria.py new file mode 100644 index 0000000..02a7726 --- /dev/null +++ b/models/criteria.py @@ -0,0 +1,654 @@ +from functools import lru_cache + +import torch +import torch.nn.functional as F +from torch import nn + +from models.utils import allgather_wgrad +from utils.dist import get_rank, get_world_size +from utils.easydict import EasyDict + + +def get_sim( + x_proj: torch.Tensor, + y_proj: torch.Tensor, + temp=1.0, +): + """calculate pair-wise similarity between two modalities x and y. + + Args: + x_proj (torch.Tensor): The representation of modality x. Shape: [B,T,C] or [B,C]. + y_proj (torch.Tensor): The representation of modality y. Shape: [B,C]. + temp (torch.Tensor): The temperature. Shape: []. + + Returns: The similarity between modality x and y. Shape: [B,B]. + + """ + x_proj = F.normalize(x_proj, dim=-1) + y_proj = F.normalize(y_proj, dim=-1) + assert x_proj.dim() in [2, 3] + assert y_proj.dim() == 2 + if x_proj.dim() == 2: + sim_x2y = torch.einsum("md,nd->mn", x_proj, y_proj) / temp # (B,B) + else: + sim_x2y = torch.einsum("mld,nd->mln", x_proj, y_proj).mean(1) / temp # (B,B) + sim_y2x = sim_x2y.T + return sim_x2y, sim_y2x + + +class ContMatchLoss(nn.Module): + def __init__(self): + super(ContMatchLoss, self).__init__() + + @torch.no_grad() + def get_mask(self, sim, idx=None, normalize=False): + """ + Args: + sim (torch.Tensor): The similarity between videos and texts. shape: (B, B). + idx (torch.Tensor): The index for each video. Shape: [B]. + normalize (bool): If true, make row sum equal to 1 + """ + if idx is not None: + idx = idx.view(-1, 1) + mask = torch.eq(idx, idx.T).to(sim.dtype) + if normalize: + mask = mask / mask.sum(1, keepdim=True) + else: + mask = torch.zeros_like(sim) + mask.fill_diagonal_(1) + return mask # `1` mark valid/matched location + + @lru_cache(maxsize=16) + def get_gather_args(self): + """obtain the args for all_gather + Returns: dict. + + """ + return EasyDict({"world_size": get_world_size(), "rank": get_rank()}) + + +class STC_STM_Loss(ContMatchLoss): + """Contrastive and matching losses""" + + def __init__(self): + super(STC_STM_Loss, self).__init__() + + def stc_loss( + self, + temporal_proj: torch.Tensor, + spatial_proj: torch.Tensor, + idx: torch.Tensor, + temp=1.0, + all_gather=True + ): + + """forward to calculate the loss + + Args: + vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. + text_proj (torch.Tensor): The text representation. Shape: [B,C]. + idx (torch.Tensor): The index for each example. Shape: [B,]. + temp (torch.Tensor): The temperature. Shape: []. + all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. + + Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. + + """ + if all_gather: + gather_args = self.get_gather_args() + temporal_proj = allgather_wgrad(temporal_proj, gather_args) + spatial_proj = allgather_wgrad(spatial_proj, gather_args) + if idx is not None: + idx = allgather_wgrad(idx, gather_args) + + sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp) + + with torch.no_grad(): + sim_t2s_targets = self.get_mask(sim_t2s, idx=idx, normalize=True) + sim_s2t_targets = sim_t2s_targets + + loss_t2s = -torch.sum(F.log_softmax(sim_t2s, dim=1) * sim_t2s_targets, dim=1).mean() + loss_s2t = -torch.sum(F.log_softmax(sim_s2t, dim=1) * sim_s2t_targets, dim=1).mean() + + loss_stc = (loss_t2s + loss_s2t) / 2 + return loss_stc + + def stm_loss( + self, + grounding_expert, + stm_head, + # temp, + spatial_embeds_orig, + temporal_embeds_orig, + temporal_proj, + spatial_proj, + idx, + generation=False, + temp=1.0 + ): + spatial_embeds = spatial_embeds_orig.clone() + temporal_embeds = temporal_embeds_orig.clone() + with torch.no_grad(): + sim_s2t, sim_t2s = get_sim(temporal_proj, spatial_proj, temp) + spatial_atts = torch.ones( + spatial_embeds.size()[:-1], dtype=torch.long, device=spatial_embeds.device + ) + temporal_atts = torch.ones( + temporal_embeds.size()[:-1], dtype=torch.long, device=temporal_embeds.device + ) + weights_s2t = F.softmax(sim_s2t + 1e-4, dim=1) # (N, N) + weights_t2s = F.softmax(sim_t2s + 1e-4, dim=1) + + mask = self.get_mask(sim_s2t, idx=idx).bool() + weights_s2t.masked_fill_(mask, 0) + weights_t2s.masked_fill_(mask, 0) + weights_s2t = torch.nan_to_num_(weights_s2t, nan=1e-2, posinf=1e-2, neginf=1e-2) + weights_t2s = torch.nan_to_num_(weights_t2s, nan=1e-2, posinf=1e-2, neginf=1e-2) + + if generation: + with torch.no_grad(): + output = grounding_expert( + encoder_embeds=temporal_embeds, + attention_mask=temporal_atts, + encoder_hidden_states=spatial_embeds, + encoder_attention_mask=spatial_atts, + return_dict=True, + ) + pos_feats = output.last_hidden_state + return pos_feats + + else: + # select a hard negatives within the batch + spatial_neg_indices = torch.multinomial(weights_s2t, 1).squeeze() + temporal_neg_indices = torch.multinomial(weights_t2s, 1).squeeze() + + + spatial_embeds_neg = spatial_embeds[spatial_neg_indices] # [B, L, c] + temporal_embeds_neg = temporal_embeds[temporal_neg_indices] # [B, L, d] + # temporal_atts_neg = temporal_atts[temporal_neg_indices] + + # concat embeddings + spatial_embeds_all = torch.cat([spatial_embeds, spatial_embeds_neg, spatial_embeds], dim=0) + temporal_embeds_all = torch.cat([temporal_embeds, temporal_embeds, temporal_embeds_neg], dim=0) + spatial_atts_all = torch.cat([spatial_atts, spatial_atts, spatial_atts], dim=0) + temporal_atts_all = torch.cat([temporal_atts, temporal_atts, temporal_atts], dim=0) + + output = grounding_expert( + inputs_embeds=temporal_embeds_all, + attention_mask=temporal_atts_all, + cross_embeds=spatial_embeds_all, + cross_attention_mask=spatial_atts_all, + ) + + stm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) + + stm_logits = stm_head(stm_embeds) # [3*B, 2] + + bs = stm_logits.shape[0] // 3 + stm_labels = stm_logits.new_ones(3 * bs, dtype=torch.long) + stm_labels[bs:] = 0 + loss_stm = F.cross_entropy(stm_logits, stm_labels) + pos_feats = output.last_hidden_state[:bs] + + return loss_stm, pos_feats + + +class VCC_VCM_Loss(ContMatchLoss): + """Contrastive and matching losses""" + + def __init__(self): + super(VCC_VCM_Loss, self).__init__() + + def vcc_loss( + self, + vis_proj: torch.Tensor, + cap_proj: torch.Tensor, + idx: torch.Tensor, + temp=1.0, + all_gather=True + ): + + """forward to calculate the loss + + Args: + vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. + text_proj (torch.Tensor): The text representation. Shape: [B,C]. + idx (torch.Tensor): The index for each example. Shape: [B,]. + temp (torch.Tensor): The temperature. Shape: []. + all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. + + Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. + + """ + if all_gather: + gather_args = self.get_gather_args() + vis_proj = allgather_wgrad(vis_proj, gather_args) + cap_proj = allgather_wgrad(cap_proj, gather_args) + if idx is not None: + idx = allgather_wgrad(idx, gather_args) + + sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp) + + with torch.no_grad(): + sim_v2c_targets = self.get_mask(sim_v2c, idx=idx, normalize=True) + sim_c2v_targets = sim_v2c_targets + + loss_v2c = -torch.sum(F.log_softmax(sim_v2c, dim=1) * sim_v2c_targets, dim=1).mean() + loss_c2v = -torch.sum(F.log_softmax(sim_c2v, dim=1) * sim_c2v_targets, dim=1).mean() + + loss_vcc = (loss_v2c + loss_c2v) / 2 + return loss_vcc + + def vcm_loss( + self, + grounding_expert, + vcm_head, + vis_embeds_orig, + cap_embeds_orig, + vis_proj, + cap_proj, + cap_atts, + idx, + generation=False, + temp=1.0 + ): + vis_embeds = vis_embeds_orig.clone() + cap_embeds = cap_embeds_orig.clone() + + with torch.no_grad(): + sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp) + vis_atts = torch.ones( + vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device + ) + + weights_v2c = F.softmax(sim_v2c + 1e-4, dim=1) # (N, N) + weights_c2v = F.softmax(sim_c2v + 1e-4, dim=1) + + mask = self.get_mask(weights_v2c, idx=idx).bool() + weights_v2c.masked_fill_(mask, 0) + weights_c2v.masked_fill_(mask, 0) + weights_v2c = torch.nan_to_num_(weights_v2c, nan=1e-2, posinf=1e-2, neginf=1e-2) + weights_c2v = torch.nan_to_num_(weights_c2v, nan=1e-2, posinf=1e-2, neginf=1e-2) + + if generation: + with torch.no_grad(): + output = grounding_expert( + encoder_embeds=cap_embeds, + attention_mask=cap_atts, + encoder_hidden_states=vis_embeds, + encoder_attention_mask=vis_atts, + return_dict=True, + ) + pos_feats = output.last_hidden_state + return pos_feats + + else: + + + # select a hard negatives within the batch + vis_neg_indices = torch.multinomial(weights_v2c, 1).squeeze() + cap_neg_indices = torch.multinomial(weights_c2v, 1).squeeze() + + + vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c] + cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, d] + cap_atts_neg = cap_atts[cap_neg_indices] + + # concat embeddings + vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0) + cap_embeds_all = torch.cat([cap_embeds, cap_embeds, cap_embeds_neg], dim=0) + vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0) + cap_atts_all = torch.cat([cap_atts, cap_atts, cap_atts_neg], dim=0) + + output = grounding_expert( + inputs_embeds=cap_embeds_all, + attention_mask=cap_atts_all, + cross_embeds=vis_embeds_all, + cross_attention_mask=vis_atts_all, + ) + + vcm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) + + vcm_logits = vcm_head(vcm_embeds) # [3*B, 2] + + bs = vcm_logits.shape[0] // 3 + vcm_labels = vcm_logits.new_ones(3 * bs, dtype=torch.long) + vcm_labels[bs:] = 0 + loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) + pos_feats = output.last_hidden_state[:bs] + return loss_vcm, pos_feats + + +class VHC_VHM_Loss(ContMatchLoss): + """Contrastive and matching losses""" + + def __init__(self): + super(VHC_VHM_Loss, self).__init__() + + def vhc_loss( + self, + vis_proj: torch.Tensor, + hist_proj: torch.Tensor, + idx: torch.Tensor, + temp=1.0, + all_gather=True + ): + + """forward to calculate the loss + + Args: + vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. + text_proj (torch.Tensor): The text representation. Shape: [B,C]. + idx (torch.Tensor): The index for each example. Shape: [B,]. + temp (torch.Tensor): The temperature. Shape: []. + all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. + + Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. + + """ + if all_gather: + gather_args = self.get_gather_args() + vis_proj = allgather_wgrad(vis_proj, gather_args) + hist_proj = allgather_wgrad(hist_proj, gather_args) + if idx is not None: + idx = allgather_wgrad(idx, gather_args) + + sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp) + + with torch.no_grad(): + sim_v2h_targets = self.get_mask(sim_v2h, idx=idx, normalize=True) + sim_h2v_targets = sim_v2h_targets + + loss_v2h = -torch.sum(F.log_softmax(sim_v2h, dim=1) * sim_v2h_targets, dim=1).mean() + loss_h2v = -torch.sum(F.log_softmax(sim_h2v, dim=1) * sim_h2v_targets, dim=1).mean() + + loss_vhc = (loss_v2h + loss_h2v) / 2 + return loss_vhc + + def vhm_loss( + self, + grounding_expert, + vhm_head, + vis_embeds_orig, + hist_embeds_orig, + vis_proj, + hist_proj, + hist_atts, + idx, + generation=False, + temp=1.0, + ): + vis_embeds = vis_embeds_orig.clone() + hist_embeds = hist_embeds_orig.clone() + with torch.no_grad(): + sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp) + vis_atts = torch.ones( + vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device + ) + + weights_v2h = F.softmax(sim_v2h + 1e-4, dim=1) # (N, N) + weights_h2v = F.softmax(sim_h2v + 1e-4, dim=1) + + mask = self.get_mask(weights_v2h, idx=idx).bool() + weights_v2h.masked_fill_(mask, 0) + weights_h2v.masked_fill_(mask, 0) + weights_v2h = torch.nan_to_num_(weights_v2h, nan=1e-2, posinf=1e-2, neginf=1e-2) + weights_h2v = torch.nan_to_num_(weights_h2v, nan=1e-2, posinf=1e-2, neginf=1e-2) + + if generation: + with torch.no_grad(): + output = grounding_expert( + encoder_embeds=hist_embeds, + attention_mask=hist_atts, + encoder_hidden_states=vis_embeds, + encoder_attention_mask=vis_atts, + return_dict=True, + # mode="fusion", + ) + pos_feats = output.last_hidden_state + return pos_feats + + else: + # select a hard negatives within the batch + vis_neg_indices = torch.multinomial(weights_v2h, 1).squeeze() + hist_neg_indices = torch.multinomial(weights_h2v, 1).squeeze() + + vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c] + hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d] + hist_atts_neg = hist_atts[hist_neg_indices] + + # concat embeddings + vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0) + hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0) + vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0) + hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0) + + output = grounding_expert( + inputs_embeds=hist_embeds_all, + attention_mask=hist_atts_all, + cross_embeds=vis_embeds_all, + cross_attention_mask=vis_atts_all, + ) + + vhm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) + + vhm_logits = vhm_head(vhm_embeds) # [3*B, 2] + + bs = vhm_logits.shape[0] // 3 + vhm_labels = vhm_logits.new_ones(3 * bs, dtype=torch.long) + vhm_labels[bs:] = 0 + loss_vhm = F.cross_entropy(vhm_logits, vhm_labels) + pos_feats = output.last_hidden_state[:bs] + + return loss_vhm, pos_feats + + +class CHC_CHM_Loss(ContMatchLoss): + """Contrastive and matching losses""" + + def __init__(self): + super(CHC_CHM_Loss, self).__init__() + + def chc_loss( + self, + cap_proj: torch.Tensor, + hist_proj: torch.Tensor, + idx: torch.Tensor, + temp=1.0, + all_gather=True + ): + + """forward to calculate the loss + + Args: + vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. + text_proj (torch.Tensor): The text representation. Shape: [B,C]. + idx (torch.Tensor): The index for each example. Shape: [B,]. + temp (torch.Tensor): The temperature. Shape: []. + all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. + + Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. + + """ + if all_gather: + gather_args = self.get_gather_args() + cap_proj = allgather_wgrad(cap_proj, gather_args) + hist_proj = allgather_wgrad(hist_proj, gather_args) + if idx is not None: + idx = allgather_wgrad(idx, gather_args) + + sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp) + + with torch.no_grad(): + sim_c2h_targets = self.get_mask(sim_c2h, idx=idx, normalize=True) + sim_h2c_targets = sim_c2h_targets + + loss_c2h = -torch.sum(F.log_softmax(sim_c2h, dim=1) * sim_c2h_targets, dim=1).mean() + loss_h2c = -torch.sum(F.log_softmax(sim_h2c, dim=1) * sim_h2c_targets, dim=1).mean() + + loss_chc = (loss_c2h + loss_h2c) / 2 + return loss_chc + + def chm_loss( + self, + grounding_expert, + chm_head, + cap_embeds_orig, + hist_embeds_orig, + cap_proj, + hist_proj, + cap_atts, + hist_atts, + idx, + generation=False, + temp=1.0 + ): + cap_embeds = cap_embeds_orig.clone() + hist_embeds = hist_embeds_orig.clone() + with torch.no_grad(): + sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp) + + weights_c2h = F.softmax(sim_c2h + 1e-4, dim=1) # (N, N) + weights_h2c = F.softmax(sim_h2c + 1e-4, dim=1) + + mask = self.get_mask(weights_c2h, idx=idx).bool() + weights_c2h.masked_fill_(mask, 0) + weights_h2c.masked_fill_(mask, 0) + weights_c2h = torch.nan_to_num_(weights_c2h, nan=1e-2, posinf=1e-2, neginf=1e-2) + weights_h2c = torch.nan_to_num_(weights_h2c, nan=1e-2, posinf=1e-2, neginf=1e-2) + + if generation: + with torch.no_grad(): + output = grounding_expert( + encoder_embeds=hist_embeds, + attention_mask=hist_atts, + encoder_hidden_states=cap_embeds, + encoder_attention_mask=cap_atts, + return_dict=True, + ) + pos_feats = output.last_hidden_state + return pos_feats + else: + # select a hard negatives within the batch + cap_neg_indices = torch.multinomial(weights_c2h, 1).squeeze() + hist_neg_indices = torch.multinomial(weights_h2c, 1).squeeze() + + cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, c] + cap_atts_neg = cap_atts[cap_neg_indices] + hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d] + hist_atts_neg = hist_atts[hist_neg_indices] + + # concat embeddings + cap_embeds_all = torch.cat([cap_embeds, cap_embeds_neg, cap_embeds], dim=0) + hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0) + cap_atts_all = torch.cat([cap_atts, cap_atts_neg, cap_atts], dim=0) + hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0) + + output = grounding_expert( + inputs_embeds=hist_embeds_all, + attention_mask=hist_atts_all, + cross_embeds=cap_embeds_all, + cross_attention_mask=cap_atts_all, + ) + + chm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) + + chm_logits = chm_head(chm_embeds) # [3*B, 2] + + bs = chm_logits.shape[0] // 3 + chm_labels = chm_logits.new_ones(3 * bs, dtype=torch.long) + chm_labels[bs:] = 0 + loss_chm = F.cross_entropy(chm_logits, chm_labels) + pos_feats = output.last_hidden_state[:bs] + return loss_chm, pos_feats + + +class MLMLoss(nn.Module): + """masked language modeling loss.""" + + def __init__(self, masking_prob, tokenizer): + super(MLMLoss, self).__init__() + self.tokenizer = tokenizer + self.masking_prob = masking_prob + + def mlm_loss( + self, + text_encoder, + text, + text_embeds, + vision_embeds, + vision_atts, + ): + input_ids = text.input_ids.clone() + labels = input_ids.clone() + probability_matrix = torch.full(labels.shape, self.masking_prob) + input_ids, labels = self.mask( + input_ids, + text_encoder.config.vocab_size, + input_ids.device, + targets=labels, + probability_matrix=probability_matrix, + ) + + # intermediate_mlm_output = text_encoder.bert( + # input_ids, + # attention_mask=text.attention_mask, + # encoder_hidden_states=vision_embeds, + # encoder_attention_mask=vision_atts, + # return_dict=True, + # # mode="text", + # ) + + # text_embeds = intermediate_mlm_output.last_hidden_state + + mlm_output = text_encoder( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=vision_embeds, + encoder_attention_mask=vision_atts, + return_dict=True, + labels=labels, + soft_labels=None, + # mode="fusion", + ) + return mlm_output.loss + + def mask( + self, + input_ids, + vocab_size, + device, + targets=None, + masked_indices=None, + probability_matrix=None, + ): + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + # We only compute loss on masked tokens + targets[~masked_indices] = -100 + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = ( + torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + ) + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids diff --git a/models/modules/__init__.py b/models/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/modules/temporal_modelling.py b/models/modules/temporal_modelling.py new file mode 100644 index 0000000..55f2d50 --- /dev/null +++ b/models/modules/temporal_modelling.py @@ -0,0 +1,286 @@ +import logging +import math + +import einops +import torch +from einops import rearrange +from timm.models.layers.drop import DropPath +from torch import nn +from torch.nn import LayerNorm, Linear, MultiheadAttention + +logger = logging.getLogger(__name__) + + +class STAdapter(nn.Module): + """ST Adapter""" + + def __init__( + self, + kernel_size=(3, 3, 3), + input_dim=768, + hidden_dim=384, + img_size=224, + patch_size=16, + drop_prob=0.1, + ): + super(STAdapter, self).__init__() + self.kernel_size = kernel_size + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.h = self.w = img_size // patch_size + + self.linear1 = nn.Linear(input_dim, hidden_dim) + self.linear2 = nn.Linear(hidden_dim, input_dim) + self.act = nn.ReLU() + self.conv = nn.Conv3d( + hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim + ) + self.droppath = DropPath(drop_prob=drop_prob) + + self.scale = nn.parameter.Parameter(torch.zeros([])) + + def forward(self, x: torch.Tensor): + """forward + + Args: + x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + + shortcut = x + x = self.linear1(x) + cls = x[:, :, :1, :] + tokens = x[:, :, 1:, :] + tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous() + tokens = self.conv(tokens) + tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c") + x = torch.cat([cls, tokens], dim=2) # [b, t, 1+h*w, c] + x = self.act(x) + x = self.linear2(x) + + return shortcut + self.scale * self.droppath(x) + + +class SpatialAttention(nn.Module): + """Perfrom spatial self-attention""" + + def __init__(self, input_dim=768, droppath_rate=0.1): + super(SpatialAttention, self).__init__() + self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True) + self.norm = LayerNorm(input_dim, eps=1e-12) + self.linear = Linear(input_dim, input_dim) + self.droppath = DropPath(droppath_rate) + # self.scale = nn.parameter.Parameter(torch.zeros([])) + self.scale = 1.0 + + def forward(self, x: torch.Tensor): + if x.shape[1] == 1: + x = self.norm(x) + x = einops.rearrange(x, "b t l c -> b (t l) c") + return x # return self if media is image + + shortcut = x + x = einops.rearrange(x, 'b t l c -> (b t) l c') + x = self.norm(x) + x = self.attn(x, x, x)[0] + x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0]) + x = shortcut + self.scale * self.droppath(x) + x = einops.rearrange(x, "b t l c -> b (t l) c") + return x + + +class TemporalAttention(nn.Module): + + """perform temporal self-attention""" + + def __init__(self, input_dim=768, droppath_rate=0.1): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super(TemporalAttention, self).__init__() + + self._input_dim = input_dim + self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True) + self.norm = LayerNorm(input_dim, eps=1e-12) + self.linear = Linear(input_dim, input_dim) + self.droppath = DropPath(droppath_rate) + # self.scale = nn.parameter.Parameter(torch.zeros([])) + self.scale = 1.0 + + def forward(self, x: torch.Tensor): + """forward + + Args: + x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + + """ + if x.shape[1] == 1: # for single frame, return itself. + x = self.norm(x) + x = einops.rearrange(x, "b t l c -> b (t l) c") + return x + + shortcut = x + x = einops.rearrange(x, "b t l c -> (b l) t c") + x = self.norm(x) + x = self.attn(x, x, x)[0] + x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0]) + x = shortcut + self.scale * self.droppath(x) + x = einops.rearrange(x, "b t l c -> b (t l) c") + return x + + +class WindowTemporalAttention(nn.Module): + + """perform windowed temporal self-attention""" + + def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + + self._input_dim = input_dim + self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64) + self.norm = LayerNorm(input_dim, eps=1e-12) + self.droppath = DropPath(droppath_rate) + self.scale = nn.parameter.Parameter(torch.zeros([])) + self.wh, self.ww = window_size + # logger.info(f"WindowTemporalAttention: window_size: {window_size}") + + def forward(self, x: torch.Tensor): + """forward + + Args: + x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + shortcut = x + + h = w = int(math.sqrt(x.shape[2] - 1)) + cls_token = x[:, :, :1, :] + x = einops.rearrange( + x[:, :, 1:, :], + "b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c", + nh=h // self.wh, + wh=self.wh, + nw=w // self.ww, + ww=self.ww, + ) + x = self.norm(x) + x = self.temporal_attn(x, x, x)[0] + x = einops.rearrange( + x, + "(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c", + wh=self.wh, + ww=self.ww, + nh=h // self.wh, + nw=w // self.ww, + ) + # add back cls token. + x = torch.concat([cls_token, x], dim=2) + return shortcut + self.scale * self.droppath(x) + + +class X_CLIP(nn.Module): + + """perform windowed temporal self-attention""" + + def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + + d_model = input_dim + + self.message_fc = nn.Linear(d_model, d_model) + self.message_ln = LayerNorm(d_model, eps=1e-12) + self.message_attn = nn.MultiheadAttention(d_model, d_model // 64) + self.num_prompts = num_prompts + + self.droppath = DropPath(droppath_rate) + + def forward(self, x: torch.Tensor): + """forward + + Args: + x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + msg_token = self.message_ln(self.message_fc(x[:, :, 0, :])) # [b, t, c] + msg_token = rearrange(msg_token, "b t c -> t b c") + msg_token = msg_token + self.droppath( + self.message_attn(msg_token, msg_token, msg_token)[0] + ) + msg_token = rearrange(msg_token, "t b c -> b t c") + # replace the last prompt token with msg_token. + x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2) # [b, t, l+1, c] + return x + + +class TemporalS4(nn.Module): + + """perform temporal self-attention""" + + def __init__(self, input_dim=768, droppath_rate=0.1): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + from .s4 import S4 + + self._input_dim = input_dim + self.norm = LayerNorm(input_dim, eps=1e-12) + self.droppath = DropPath(droppath_rate) + self.scale = nn.parameter.Parameter(torch.zeros([])) + self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True) + + def forward(self, x: torch.Tensor): + """forward + + Args: + x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + + shortcut = x + x = self.norm(x) + x = einops.rearrange(x, "b t l c -> b c (t l)") + x, _ = self.s4(x) + x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1]) + return shortcut + self.scale * self.droppath(x) diff --git a/models/setup.py b/models/setup.py new file mode 100644 index 0000000..95d5be2 --- /dev/null +++ b/models/setup.py @@ -0,0 +1,358 @@ +import copy +import os.path as osp +import glog as logger + +import torch +from torch.utils.data import ConcatDataset +from models.backbones.beit.builder import interpolate_pos_embed_beit +from models.backbones.bert.tokenization_bert import BertTokenizer +from transformers import T5Tokenizer, BartTokenizer, LlamaTokenizer +from utils.optimizer import create_optimizer +from utils.scheduler import create_scheduler +from datasets.dataloader import load_dataloaders +from datasets.pretraining import load_datasets as load_datasets_stage_1 +from datasets.visdial_dataset import load_visdial_dataset +from datasets.champagne_dataset import load_champagne_dataset +from datasets.nextqa_dataset import load_nextqa_dataset +from datasets.avsd_dataset import load_avsd_dataset +# from datasets.avsd_dataset_like_mixer import load_avsd_dataset + +from processors.blip_processors import Blip2ImageTrainProcessor +from processors.blip_processors import BlipCaptionProcessor, BlipDialogProcessor + +from utils.init import set_training_steps +# from models.v2dial import V2Dial, V2DialBase +from models.v2dial import V2DialBase, V2Dial, V2DialNoMoes + +# from datasets.avsd_dataset import get_dataset, AVSDDataSet +from torch.utils.data import DataLoader + + +def setup_model( + config, has_decoder=False, pretrain=False, find_unused_parameters=True +): + logger.info("Creating model") + + if config['stage'] == 'stage_1': + config = copy.deepcopy(config) + + # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + # model = V2DialBase(config=config, expert_tokenizer=tokenizer) + model = V2DialBase(config) + model = model.to(torch.device('cuda')) + model_without_ddp = model + optimizer = create_optimizer(config, model) + scheduler = create_scheduler(config, optimizer) + scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) + + if config['distributed']: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[config['gpu']], + find_unused_parameters=find_unused_parameters, # `False` for image-only task + ) + + start_epoch = 0 + global_step = 0 + webvid_step = 0 + cc3m_step = 0 + + if osp.isfile(config['pretrained_path']): + logger.info(f"Loading checkpoint from {config['pretrained_path']}") + checkpoint = torch.load(config['pretrained_path'], map_location="cpu") + state_dict = checkpoint["model"] + + if config.resume: + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + scaler.load_state_dict(checkpoint["scaler"]) + start_epoch = checkpoint["epoch"] + 1 + global_step = checkpoint["global_step"] + elif not pretrain: # downstream init from pretrained ckpt + + # interpolate positional embeddings. + state_dict = interpolate_pos_embed_beit(state_dict, model_without_ddp) + + + #TODO Might need to update to match the MoEs + if not config.evaluate: # finetuning from a pretarined weights. + for key in list(state_dict.keys()): + if "bert" in key: + encoder_key = key.replace("bert.", "") + state_dict[encoder_key] = state_dict[key] + if not has_decoder: + del state_dict[key] + + # init text decoder as multimodal encoder (last 6 layers of model.text_encoder) + # only for generation tasks like VQA + if has_decoder and "text_encoder" in key: + if "layer" in key: + encoder_keys = key.split(".") + layer_num = int(encoder_keys[4]) + if layer_num < config.model.text_encoder.fusion_layer: + del state_dict[key] + continue + else: + decoder_layer_num = layer_num - 9 + encoder_keys[4] = str(decoder_layer_num) + encoder_key = ".".join(encoder_keys) + else: + encoder_key = key + decoder_key = encoder_key.replace("text_encoder", "text_decoder") + state_dict[decoder_key] = state_dict[key] + del state_dict[key] + + msg = model_without_ddp.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"Loaded checkpoint from {config.pretrained_path}") + else: + logger.warning("No pretrained checkpoint provided, training from scratch") + + return ( + model, + model_without_ddp, + optimizer, + scheduler, + scaler, + start_epoch, + global_step, + webvid_step, + cc3m_step, + config + ) + else: + # config = copy.deepcopy(config) + # if config['use_original_feats']: + # model = AVSDBart(config) + # else: + # # model = V2Dial(config, tokenizer_experts, tokenizer_enc_dec) + # if config.use_moes: + model = V2Dial(config) + # else: + # model = V2DialNoMoes(config) + + model = model.to(torch.device('cuda')) + model_without_ddp = model + + optimizer = None + scheduler = None + scaler = None + + start_epoch = 0 + global_step = 0 + if config['stage'] == 'stage_3': + visdial_step = 0 + avsd_step = 0 + nextqa_step = 0 + + ckpt_path = config.pretrained_path_resume if config.resume else config.pretrained_path_prev_stage + if config.generating: + ckpt_path = config.best_ckpt_path + + if osp.isfile(ckpt_path): + logger.info(f"Loading checkpoint from {ckpt_path}") + checkpoint = torch.load(ckpt_path, map_location="cpu") + state_dict = checkpoint["model"] + + if config.resume: + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + scaler.load_state_dict(checkpoint["scaler"]) + start_epoch = checkpoint["epoch"] + 1 + global_step = checkpoint["global_step"] + if config['stage'] == 'stage_3': + visdial_step = checkpoint['visdial_step'] + avsd_step = checkpoint['avsd_step'] + next_step = checkpoint['nextqa_step'] + + + if config['stage'] in ['stage_2', 'stage_3'] and config.use_moes: + # Init. the history expert erights with the caption expert weights + p_names = [ + 'moe_layers.{}.norm_hist.weight', + 'moe_layers.{}.mlp_hist.fc1.weight', + 'moe_layers.{}.mlp_hist.fc1.bias', + 'moe_layers.{}.mlp_hist.fc2.weight', + 'moe_layers.{}.mlp_hist.fc2.bias', + ] + + for moe_layer_idx in range(config.num_moe_modality_layers): + for p_name in p_names: + p_hist_name = p_name.format(moe_layer_idx) + if p_hist_name not in state_dict: + p_cap_name = p_hist_name.replace('hist', 'cap') + state_dict[p_hist_name] = state_dict[p_cap_name].clone() + + msg = model_without_ddp.load_state_dict(state_dict, strict=False) + logger.info(msg) + + logger.info(f"Loaded checkpoint from {ckpt_path}") + else: + logger.warning("No pretrained checkpoint provided, training from scratch") + + if config['training']: + optimizer = create_optimizer(config, model_without_ddp) + scheduler = create_scheduler(config, optimizer) + scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) + + elif config['generating']: + model.llm.set_input_embeddings(model.text_embedding) + + if config['distributed']: + + static_graph=config.stage!='stage_1' + if len(config.media_train) > 0: + static_graph = False + + model = torch.nn.parallel.DistributedDataParallel( + model_without_ddp, + device_ids=[config['gpu']], + find_unused_parameters=find_unused_parameters, # `False` for image-only task + static_graph=static_graph + ) + + if config['stage'] == 'stage_3': + return ( + model, + model_without_ddp, + optimizer, + scheduler, + scaler, + start_epoch, + global_step, + visdial_step, + avsd_step, + nextqa_step, + config + ) + return ( + model, + model_without_ddp, + optimizer, + scheduler, + scaler, + start_epoch, + global_step, + config + ) + + +def setup_data(config): + logger.info("[INFO] Creating datasets") + + # define the processors + vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res) + + if config['stage'] == 'stage_1': + text_processor = BlipCaptionProcessor(max_words=config.max_cap_len) + + if config['debugging']: + train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val') + else: + train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'train') + + val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val') + + # cc3m_dataset = ConcatDataset([train_datasets['cc3m'], val_datasets['cc3m']]) + + # webvid_dataset = ConcatDataset([train_datasets['webvid'], val_datasets['webvid']]) + + # train_datasets = [cc3m_dataset, webvid_dataset] + train_datasets = list(train_datasets.values()) + val_datasets = list(val_datasets.values()) + + batch_sizes = [config['batch_size_cc3m'], config['batch_size_webvid']] + num_samples = [len(d) for d in train_datasets] + config = set_training_steps(config, num_samples, batch_sizes) + + train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True) + val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True) + + # val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'test') + + # val_dataloader = load_dataloaders(config, val_datasets, 'test', output_dict=True) + + if config['stage'] == 'stage_2': + text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50 + train_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'train')] + val_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'val')] + batch_sizes = [config['batch_size_champagne']] + num_samples = [len(d) for d in train_datasets] + config = set_training_steps(config, num_samples, batch_sizes) + + train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True) + val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True) + + + if config['stage'] == 'stage_3': + text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50 + train_datasets = [] + val_datasets = [] + for medium in config['media_train']: + if medium == 'visdial': + load_dataset_fn = load_visdial_dataset + elif medium == 'avsd': + load_dataset_fn = load_avsd_dataset + elif medium == 'nextqa': + load_dataset_fn = load_nextqa_dataset + # elif medium == 'champagne': + # load_dataset_fn = load_champagne_dataset + + train_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'train')) + + for medium in config['media_val']: + if medium == 'visdial': + load_dataset_fn = load_visdial_dataset + elif medium == 'avsd': + load_dataset_fn = load_avsd_dataset + elif medium == 'nextqa': + load_dataset_fn = load_nextqa_dataset + # elif medium == 'champagne': + # load_dataset_fn = load_champagne_dataset + + val_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'val')) + + batch_sizes = [d.batch_size for d in train_datasets] + num_samples = [len(d) for d in train_datasets] + config = set_training_steps(config, num_samples, batch_sizes) + + train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True) + + val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True) + + return train_dataloaders, val_dataloaders + + +def setup_data_test(config): + vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res) + text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50 + + if config.media_test == 'visdial': + load_dataset_fn = load_visdial_dataset + elif config.media_test == 'avsd': + load_dataset_fn = load_avsd_dataset + elif config.media_test == 'nextqa': + load_dataset_fn = load_nextqa_dataset + test_dataset = load_dataset_fn(config, vis_processor, text_processor, 'test') + + test_dataloader = DataLoader( + test_dataset, shuffle=False, batch_size=test_dataset.batch_size) + + return test_dataloader + + +# def setup_data_test(config, args): +# tokenizer_experts = BertTokenizer.from_pretrained('bert-base-uncased') +# tokenizer_enc_dec = None +# if config.enc_dec_family == 'flan_t5': +# tokenizer_enc_dec = T5Tokenizer.from_pretrained(config.enc_dec_name) +# elif config.enc_dec_family == 'bart': +# tokenizer_enc_dec = BartTokenizer.from_pretrained(config.enc_dec_name) +# if config['tie_embeddings']: +# tokenizer_experts = tokenizer_enc_dec + +# if config['medium'] == 'avsd': +# test_dataset = AVSDDataSet(config, 'avsd', tokenizer_experts, tokenizer_enc_dec, 'test') +# test_dataloader = DataLoader( +# test_dataset, shuffle=False, batch_size=test_dataset.batch_size, collate_fn=test_dataset.collate_fn) +# return test_dataloader diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..5799643 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,266 @@ +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy import interpolate +from typing import List + +logger = logging.getLogger(__name__) + + +class MLM: + def __init__( + self, + mask_token: int, + padding_token: int, + no_mask_tokens: List[int], + n_tokens: int, + masking_prob: float = 0.15, + randomize_prob: float = 0.1, + no_change_prob: float = 0.1 + ): + self.mask_token = mask_token + self.padding_token = padding_token + self.no_mask_tokens = list(set(no_mask_tokens + [padding_token, mask_token])) + self.n_tokens = n_tokens + self.masking_prob = masking_prob + self.randomize_prob = randomize_prob + self.no_change_prob = no_change_prob + + def __call__(self, x: torch.Tensor): + full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob + for tok in self.no_mask_tokens: + full_mask &= x != tok # unmask unwanted tokens --> 0 + + unchanged_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob) + random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob) + random_token_idx = torch.nonzero(random_token_mask, as_tuple=True) + random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device) + mask = full_mask & ~random_token_mask & ~unchanged_mask + + y = x.clone().detach() + x.masked_fill_(mask, self.mask_token) + x[random_token_idx] = random_tokens + y.masked_fill_(~full_mask, self.padding_token) + + return x, y + + + +def _init_transformer_weights(module, initializer_range=0.02): + """Initialize the weights. Copied from transformers ViT/Bert model init""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): + """ + Add/Remove extra temporal_embeddings as needed. + https://arxiv.org/abs/2104.00650 shows adding zero paddings works. + + temp_embed_old: (1, num_frames_old, 1, d) + temp_embed_new: (1, num_frames_new, 1, d) + add_zero: bool, if True, add zero, else, interpolate trained embeddings. + """ + # TODO zero pad + num_frms_new = temp_embed_new.shape[1] + num_frms_old = temp_embed_old.shape[1] + logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") + if num_frms_new > num_frms_old: + if add_zero: + temp_embed_new[ + :, :num_frms_old + ] = temp_embed_old # untrained embeddings are zeros. + else: + temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) + elif num_frms_new < num_frms_old: + temp_embed_new = temp_embed_old[:, :num_frms_new] + else: # = + temp_embed_new = temp_embed_old + return temp_embed_new + + +def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): + """ + temp_embed_old: (1, num_frames_old, 1, d) + Returns: + temp_embed_new: (1, num_frames_new, 1, d) + """ + temp_embed_old = temp_embed_old.squeeze(2).permute( + 0, 2, 1 + ) # (1, d, num_frames_old) + temp_embed_new = F.interpolate( + temp_embed_old, num_frames_new, mode="linear" + ) # (1, d, num_frames_new) + temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( + 2 + ) # (1, num_frames_new, 1, d) + return temp_embed_new + + +def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new): + """ + Args: + pos_embed_old: (1, L_old, d), pre-trained + pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights + num_patches_new: + """ + # interpolate position embedding + embedding_size = pos_embed_old.shape[-1] + num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches_new ** 0.5) + + if orig_size != new_size: + # class_token and dist_token are kept unchanged + # the extra tokens seems always at the beginning of the position embedding + extra_tokens = pos_embed_old[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_old[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2") + return interpolated_pos_embed + else: + return pos_embed_old + + +def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new): + """ + Args: + state_dict_old: loaded state dict + state_dict_new: state dict for model with new image size + patch_shape_new: new model patch_shape + ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py + """ + all_keys = list(state_dict_old.keys()) + for key in all_keys: + if "relative_position_index" in key: + state_dict_old.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = state_dict_old[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = state_dict_new[key].size() + dst_patch_shape = patch_shape_new + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1 + ) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + # logger.info("Position interpolate for %s from %dx%d to %dx%d" % ( + # key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + # logger.info("Original positions = %s" % str(x)) + # logger.info("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)) + .contiguous() + .view(-1, 1) + .to(rel_pos_bias.device) + ) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + state_dict_old[key] = new_rel_pos_bias + return state_dict_old + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*repeat_idx) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) + + +def mask_logits(target, mask): + return target * mask + (1 - mask) * (-1e10) + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + + @staticmethod + def forward(ctx, tensor, args): + output = [torch.empty_like(tensor) for _ in range(args.world_size)] + torch.distributed.all_gather(output, tensor) + ctx.rank = args.rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, dim=0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], + None, + ) + + +allgather_wgrad = AllGather.apply diff --git a/models/v2dial.py b/models/v2dial.py new file mode 100644 index 0000000..be665c2 --- /dev/null +++ b/models/v2dial.py @@ -0,0 +1,2213 @@ +import json +import re +import glog as logging +import random +import os + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +# from minigpt4.common.registry import registry +from .backbones.blip2 import Blip2Base, disabled_train +from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration +from transformers.models.bart.modeling_bart import BartForConditionalGeneration +# from .backbones.encoder_decoder.xflan_t5 import T5ForConditionalGeneration +from .backbones.modeling_mistral import MistralForCausalLM +from .backbones.modeling_llama_v2 import LlamaForCausalLM +from .backbones.moes import MoELayer, Pooler +# from .backbones.moes_huggingface import MoEPooler +# from .backbones.moes_huggingface import MoELayer, MoEPooler +from .modules.temporal_modelling import SpatialAttention, TemporalAttention +from .common.dist_utils import concat_all_gather, all_gather_with_grad +from .utils import MLM +from utils.dist import is_main_process + +# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model +# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model +# from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub + +from transformers import AutoTokenizer, DataCollatorForLanguageModeling +from transformers import BitsAndBytesConfig + +from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings + +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + prepare_model_for_kbit_training, + set_peft_model_state_dict, +) +import time +import numpy as np + +# from minigpt4.models import policies + +class V2DialAbstract(Blip2Base): + def __init__(self): + super(V2DialAbstract, self).__init__() + + def shift_right(self, input_ids): + decoder_start_token_id = self.llm.config.decoder_start_token_id + pad_token_id = self.llm.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " + "See T5 docs for more information." + ) + + + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + def encode_vis(self, image, device, is_vid=True): + num_frames = image.size(1) + bs_pre_reshape = image.size(0) + if len(image.shape) > 4: + image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + # with self.maybe_autocast(): # inherited from Blip2Base + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + + bs, pn, hs = image_embeds.shape + if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) + image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + vis_embed = self.vit_proj(image_embeds) # project to LLM input size (200,64,5632) -> (200,64, d_hidden) + + # reshape the video features + vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) + + # Perfrom spatial temporal attention + vis_embed_spatial = self.spatial_att(vis_embed) + vis_feat_len = vis_embed_spatial.size(1) + + if not self.config.embed_from_llm: + vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) + vis_spatial_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + vis_embed_temporal, vis_temporal_mask = None, None + + if is_vid: + vis_embed_temporal = self.temporal_att(vis_embed) + if not self.config.embed_from_llm: + vis_embed_temporal = vis_embed_temporal + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_feat_len).long().to(device)) + vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + return vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask + + def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None): + if max_len: + text_tokenized = self.tokenizer( + text, + return_tensors='pt', + padding='max_length', + max_length=max_len, + truncation=True, + add_special_tokens=False, + return_special_tokens_mask=True + ).to(device) + else: + text_tokenized = self.tokenizer( + text, + return_tensors='pt', + padding='longest', + add_special_tokens=False, + return_special_tokens_mask=True + ).to(device) + + text_ids = text_tokenized.input_ids + text_attention_mask = text_tokenized.attention_mask + + if add_bos: + bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device) + bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) + + text_ids = torch.cat([bos_ids, text_ids], dim=1) + text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1) + + if add_eos: + eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device) + eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) + + text_ids = torch.cat([text_ids, eos_ids], dim=1) + text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1) + + + return text_ids, text_attention_mask + + def get_extended_attention_mask(self, attention_mask=None): + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + return extended_attention_mask + + @staticmethod + def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + + +class V2DialBase(V2DialAbstract): + def __init__(self, config): + super(V2DialBase, self).__init__() + self.config = config + + ################## 1. Select Tokenizer -- We use BERT tokenizer ################## + bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size)) + + tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size)) + + text_embedding = BertEmbeddings(bert_config) + text_embedding.apply(self.init_weights) + + token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # Number of modality types (temp/spa/text) + token_type_embedding.apply(self.init_weights) + + # Define the masking strategy + mlm_collactor = DataCollatorForLanguageModeling( + tokenizer, mlm=True, mlm_probability=config.masking_prob, return_tensors='pt') + + ################## 2. Select the backbone ViT ################## + logging.info('[INFO] Loading ViT in progress') + if config.freeze_vit: + # vit_precision = 'fp16' if config.fp16 else 'fp32' + logging.info(f'[INFO] ViT precision: {config.vit_precision}') + visual_encoder, ln_vision = self.init_vision_encoder( + config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision + ) + for name, param in visual_encoder.named_parameters(): + param.requires_grad = False + visual_encoder = visual_encoder.eval() + visual_encoder.train = disabled_train + for name, param in ln_vision.named_parameters(): + param.requires_grad = False + ln_vision = ln_vision.eval() + ln_vision.train = disabled_train + logging.info('[INFO] ViT frozen') + + else: + vit_precision = 'fp32' + visual_encoder, ln_vision = self.init_vision_encoder( + config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision + ) + logging.info('[INFO] ViT hot') + logging.info('[INFO] ViT successfully loaded') + + ################## 3. Define the ViT-Expert communication Interface ################## + self.system_prompt = False + self.vit_token_pooling = config.vit_token_pooling + if self.vit_token_pooling: + vit_proj = nn.Linear( + 1408*4, bert_config.hidden_size + ) + else: + vit_proj = nn.Linear( + 1408, bert_config.hidden_size + ) + vit_proj.apply(self.init_weights) + + spatial_att = SpatialAttention(input_dim=bert_config.hidden_size) + temporal_att = TemporalAttention(input_dim=bert_config.hidden_size) + + spatial_att.apply(self.init_weights) + temporal_att.apply(self.init_weights) + + ################## 4. Define the Expert layers ################## + moe_layers = [] + + for moe_layer_idx in range(config.num_moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + moe_layer = MoELayer( + bert_config.hidden_size, + bert_config.num_attention_heads, + expert_flag, + use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts + ) + moe_layer.apply(self.init_weights) + moe_layers.append(moe_layer) + + logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded') + + moe_layers = nn.ModuleList(moe_layers) + moe_norm = nn.LayerNorm(bert_config.hidden_size) + + ################## 5. Define the projection layers for contrastive learning ################## + temp_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + spatial_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + vision_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + cap_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + + temp_proj.apply(self.init_weights) + spatial_proj.apply(self.init_weights) + vision_proj.apply(self.init_weights) + cap_proj.apply(self.init_weights) + + ################## 6. Define the pooler for matching loss ################## + pooler = Pooler(bert_config.hidden_size) + pooler.apply(self.init_weights) + + ################## 5. Attach the matching heads ################## + stm_head = nn.Linear(bert_config.hidden_size, 2) + vcm_head = nn.Linear(bert_config.hidden_size, 2) + lm_head = nn.Linear(bert_config.hidden_size, len(tokenizer)) + + stm_head.apply(self.init_weights) + vcm_head.apply(self.init_weights) + lm_head.apply(self.init_weights) + + temp = nn.Parameter(0.07 * torch.ones([])) + # temp = 0.07 + + # Attach the components to self + self.tokenizer = tokenizer + self.mlm_collactor = mlm_collactor + self.text_embedding = text_embedding + self.token_type_embedding = token_type_embedding + self.visual_encoder = visual_encoder + self.ln_vision = ln_vision + self.vit_proj = vit_proj + self.moe_layers = moe_layers + self.moe_norm = moe_norm + self.spatial_att = spatial_att + self.temporal_att = temporal_att + self.temp_proj = temp_proj + self.spatial_proj = spatial_proj + self.vision_proj = vision_proj + self.cap_proj = cap_proj + self.pooler = pooler + self.stm_head = stm_head + self.vcm_head = vcm_head + self.lm_head = lm_head + self.temp = temp + + @staticmethod + def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def build_query_embeds(self, num_query_tokens, dim_query_tokens): + query_embeds = nn.Parameter( + torch.zeros(1, num_query_tokens, dim_query_tokens) + ) + query_embeds.data.normal_(mean=0.0, std=0.02) + return query_embeds + + def encode_caption(self, cap): + cap_output = self.cap_expert( + input_ids=cap.input_ids, + attention_mask=cap.attention_mask, + return_dict=True, + ) + cap_embeds = cap_output.last_hidden_state + pooled_cap_embeds = cap_embeds[:, 0] + return cap_embeds, pooled_cap_embeds + + def encode_vis_old(self, vis, media_type): + # if media_type == 'webvid': + # bs, num_frames, c, h, w = vis.size() + # # reshape + # vis = vis.view(bs * num_frames, c, h, w) + vis_embed = self.beit(vis).last_hidden_state + # vis_embed = self.beit_layernorm(vis_output.last_hidden_state) + # remove cls token embedding + vis_embed = vis_embed[:, :, 1:, :] + vis_embed = self.beit_lin(vis_embed) + # perform spatial attention + vis_spatial_embed = self.spatial_att(vis_embed) + vis_temp_embed = self.tempotal_att(vis_embed) if media_type in ['webvid', 'msrvtt', 'champagne', 'avsd'] else None + + return vis_spatial_embed, vis_temp_embed + + def encode_queries(self, query_embeds, vis_embeds, vis_mode): + if vis_mode == 'spatial': + expert = self.spatial_expert + layer_norm = self.spatial_layernorm + elif vis_mode == 'temporal': + expert = self.temporal_expert + layer_norm = self.temporal_layernorm + else: + raise ValueError(f'[ERROR] {vis_mode} not implemented!') + + attention_mask = torch.ones( + query_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device) + + vis_attention_mask = torch.ones( + vis_embeds.size()[:-1], dtype=torch.long).to(vis_embeds.device) + + if self.config['expert_layer_type'] == 'bert': + + output_dict = expert( + encoder_embeds=query_embeds, + encoder_hidden_states=vis_embeds, + encoder_attention_mask=vis_attention_mask, + ) + query_embeds = layer_norm(output_dict.last_hidden_state) + pooled_query_embeds = output_dict.pooler_output + + elif self.config['expert_layer_type'] == 'bart': + output_dict = expert( + inputs_embeds=query_embeds, + attention_mask=attention_mask, + cross_embeds=vis_embeds, + cross_attention_mask=vis_attention_mask, + ) + + query_embeds = layer_norm(output_dict.last_hidden_state) + pooled_query_embeds = query_embeds[:, 0] + + return query_embeds, pooled_query_embeds + + # def encode_vis(self, image, device, is_vid=True): + # num_frames = image.size(1) + # bs_pre_reshape = image.size(0) + # if len(image.shape) > 4: + # image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + # # with self.maybe_autocast(): # inherited from Blip2Base + # image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + # image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + + # bs, pn, hs = image_embeds.shape + # if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) + # image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + # vis_embed = self.vit_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) + + # # reshape the video features + # vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) + + + # # Perfrom spatial temporal attention + # vis_embed_spatial = self.spatial_att(vis_embed) + # vis_feat_len = vis_embed_spatial.size(1) + + # vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) + # vis_spatial_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + # vis_embed_temporal, vis_temporal_mask = None, None + + # if is_vid: + # vis_embed_temporal = self.temporal_att(vis_embed) + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_feat_len).long().to(device)) + # vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + # return vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask + + def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True): + num_frames = image.size(1) + bs_pre_reshape = image.size(0) + if len(image.shape) > 4: + image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + # with self.maybe_autocast(): # inherited from Blip2Base + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + + bs, pn, hs = image_embeds.shape + if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) + image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + vis_embed = self.vit_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) + + # reshape the video features + vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) + size_orig = vis_embed.size() + + # Perfrom spatial temporal attention + vis_embed = self.spatial_att(vis_embed) + if is_vid: + vis_embed = vis_embed.view(size_orig) + vis_embed = self.temporal_att(vis_embed) + + vis_feat_len = vis_embed.size(1) + + vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) + vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + return vis_embed, vis_mask + + def tokenize_text(self, text, device, add_bos=False, add_eos=False, max_len=None): + if max_len: + text_tokenized = self.tokenizer( + text, + return_tensors='pt', + padding='max_length', + max_length=max_len, + truncation=True, + add_special_tokens=False, + return_special_tokens_mask=True + ).to(device) + else: + text_tokenized = self.tokenizer( + text, + return_tensors='pt', + padding='longest', + add_special_tokens=False, + return_special_tokens_mask=True + ).to(device) + + text_ids = text_tokenized.input_ids + text_attention_mask = text_tokenized.attention_mask + + if add_bos: + bos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.bos_token_id).to(device) + bos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) + + text_ids = torch.cat([bos_ids, text_ids], dim=1) + text_attention_mask = torch.cat([bos_att, text_attention_mask], dim=1) + + if add_eos: + eos_ids = torch.LongTensor(text_ids.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device) + eos_att = torch.LongTensor(text_ids.size(0), 1).fill_(1).to(device) + + text_ids = torch.cat([text_ids, eos_ids], dim=1) + text_attention_mask = torch.cat([text_attention_mask, eos_att], dim=1) + + + return text_ids, text_attention_mask + + def encode_text(self, text, max_len, device): + text_tokenized = self.tokenizer( + text, + return_tensors='pt', + padding='max_length', + max_length=max_len, + truncation=True, + add_special_tokens=False + ).to(device) + text_ids = text_tokenized.input_ids + text_embeds = self.embed(text_ids) + text_attention_mask = text_tokenized.attention_mask + return text_embeds, text_ids, text_attention_mask + + def encode_spatial_toks(self, batch_size, device): + # ['', '', '', '', ''] + + special_toks_ids = self.tokenizer( + '', + return_tensors='pt', + padding='longest', + truncation=True, + add_special_tokens=False + ).to(device) + + special_toks_embeds = self.embed(special_toks_ids.input_ids) + special_toks_embeds = special_toks_embeds.repeat(batch_size, 1, 1) + return special_toks_embeds + + def construt_input_embeds_stage_1(self, vis_embed, cap_embed, special_toks_embeds, cap_attention_mask, media_type, device): + batch_size = vis_embed.size(0) + embed_dim = vis_embed.size(-1) + vis_embed = vis_embed.view(batch_size, -1, embed_dim) + + input_embeds = [] + input_attention_mask = [] + special_toks_indices = { + '': 0, + '': 1, + '': 2, + } + # special_toks_embeds = + # for video: [spatial_featurres][temporal_featurres][caption_features] + # for image: [spatial_featurres][caption_features] + + input_embeds.append(special_toks_embeds[:, 0:3, :]) # + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + + input_embeds.append(vis_embed.clone()) # [spatial_features] + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + + if media_type == 'webvid': + # here we copy the original vis_embeds twice and will apply spatial and temporal attention later + input_embeds.append(special_toks_embeds[:, 3:4, :]) # + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 + + input_embeds.append(vis_embed.clone()) # [temporal_features] + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + + + input_embeds.append(special_toks_embeds[:, 4:5, :]) # + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + + if media_type == 'webvid': + special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 + elif media_type == 'cc3m': + special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 + + input_embeds.append(cap_embed) # [caption_features] + input_attention_mask.append(cap_attention_mask) + + input_embeds.append(special_toks_embeds[:, 6:7, :]) # + input_attention_mask.append(torch.ones(input_embeds[-1].size()[:-1], dtype=torch.long).to(device)) + special_toks_indices[''] = special_toks_indices[''] + input_embeds[-2].size(1) + 1 + + input_embeds = torch.cat(input_embeds, dim=1) + input_attention_mask = torch.cat(input_attention_mask, dim=1) + assert input_embeds.size()[:-1] == input_attention_mask.size() + + return input_embeds, input_attention_mask, special_toks_indices + + def construct_global_input(self, cap_ids, cap_attention_mask, vid_feat_len, media_type, device): + # for video: [spatial_featurres][temporal_features][caption_features] + # for image: [spatial_featurres][caption_features] + batch_size = cap_ids.size(0) + special_toks_indices = { + '': 0, + '': 1, + '': 2, + } + + ids = [self.added_vocab['']] + [self.added_vocab['']] + [self.added_vocab['']] + ids += vid_feat_len * [self.added_vocab['']] + if media_type == 'webvid': + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + ids += vid_feat_len * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + ids += cap_ids.size(1) * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + total_len = len(ids) + + ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + + ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_ids + + mask = torch.ones((batch_size, total_len), device=device) + mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_attention_mask + + return ids, mask, special_toks_indices + + def compute_contrastive_loss(self, x, y_all, y, x_all): + sim_x2y = torch.mm(x, y_all.t()) # (bs, bs*ngpus) + sim_x2y = sim_x2y / self.temp + + sim_y2x = torch.mm(y, x_all.t()) # (bs, bs*ngpus) + sim_y2x = sim_y2x / self.temp + + rank = dist.get_rank() if self.config['distributed'] else 0 + + bs = x.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + x.device + ) + loss_contrastive = ( + F.cross_entropy(sim_x2y, targets, label_smoothing=0.1) + + F.cross_entropy(sim_y2x, targets, label_smoothing=0.1) + ) / 2 + + return loss_contrastive, sim_x2y, sim_y2x + + def get_extended_attention_mask(self, attention_mask=None): + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + return extended_attention_mask + + def shared_forward( + self, + vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask, + cap_ids, cap_mask, is_vid, device): + + # is_vid = media_type == 'webvid' + # batch_size = len(cap) + vis_feat_len = vis_spatial.size(1) + input_embeds = [] + input_masks = [] + + input_embeds.append(vis_spatial) + input_masks.append(vis_spatial_mask) + + if is_vid: + input_embeds.append(vis_temporal) + input_masks.append(vis_temporal_mask) + + cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) + cap_feat_len = cap_embeds.size(1) + + input_embeds.append(cap_embeds) + input_masks.append(cap_mask) + + input_embeds = torch.cat(input_embeds, dim=1) + input_masks = torch.cat(input_masks, dim=1) + + # expand the mask + input_masks = self.get_extended_attention_mask(attention_mask=input_masks) + + # MoEs feed-forward + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + + input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks) + + #TODO normalize the output () !!!!!! + input_embeds = self.moe_norm(input_embeds) + + # return the features + spatial_feats = input_embeds[:, :vis_feat_len] + temporal_feats = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None + cap_feats = input_embeds[:, -cap_feat_len:] + cls_feats = self.pooler(cap_feats) + + moe_outputs = { + 'spatial_feats': spatial_feats, + 'temporal_feats': temporal_feats, + 'cap_feats': cap_feats, + 'cls_feats': cls_feats, + } + + return moe_outputs + + def shared_forward_no_sep_spatial_temporal_experts( + self, + vis, vis_mask, + cap_ids, cap_mask, is_vid, device): + + # is_vid = media_type == 'webvid' + # batch_size = len(cap) + vis_feat_len = vis.size(1) + input_embeds = [] + input_masks = [] + + input_embeds.append(vis) + input_masks.append(vis_mask) + + # if is_vid: + # input_embeds.append(vis_temporal) + # input_masks.append(vis_temporal_mask) + + cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) + cap_feat_len = cap_embeds.size(1) + + input_embeds.append(cap_embeds) + input_masks.append(cap_mask) + + input_embeds = torch.cat(input_embeds, dim=1) + input_masks = torch.cat(input_masks, dim=1) + + # expand the mask + input_masks = self.get_extended_attention_mask(attention_mask=input_masks) + + # MoEs feed-forward + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + + input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, is_vid=is_vid, mask=input_masks) + + #TODO normalize the output () !!!!!! + input_embeds = self.moe_norm(input_embeds) + + # return the features + vis_feats = input_embeds[:, :vis_feat_len] + cap_feats = input_embeds[:, -cap_feat_len:] + cls_feats = self.pooler(cap_feats) + + moe_outputs = { + 'vis_feats': vis_feats, + 'cap_feats': cap_feats, + 'cls_feats': cls_feats, + } + + return moe_outputs + + def vcm_iteration(self, vis, cap, neg_vis, is_vid, device): + # Prepare the vis data + # is_vid = media_type == 'webvid' + num_positive_samples = len(cap) // 2 + num_negative_samples = len(cap) - num_positive_samples + + vcm_labels = torch.cat([torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device) + vcm_labels = vcm_labels[torch.randperm(vcm_labels.size(0))].long() + + # now get the mixed vis data + + vis_mixed = [p if vcm_labels[i] == 1 else n for i, (p, n) in enumerate(zip(vis, neg_vis))] + vis_mixed = torch.stack(vis_mixed, dim=0) + + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + + if self.config.use_sep_spatial_temp_experts: + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis_mixed, device, is_vid=is_vid) + moe_outputs = self.shared_forward( + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) + else: + vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts( + vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device) + + vcm_logits = self.vcm_head(moe_outputs['cls_feats']) + loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) + return loss_vcm + + def stm_iteration(self, vis, cap, neg_vis, is_vid, device): + num_positive_samples = len(cap) // 2 + num_negative_samples = len(cap) - num_positive_samples + + stm_labels = torch.cat([torch.ones(num_positive_samples), torch.zeros(num_negative_samples)]).to(device) + stm_labels = stm_labels[torch.randperm(stm_labels.size(0))].long() + + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) + neg_vis_embed_spatial, _ , neg_vis_embed_temporal, _ = self.encode_vis(neg_vis, device, is_vid=is_vid) + + # now get the mixed vis data + vis_embed_spatial_mixed = [] + vis_embed_temporal_mixed = [] + + for i, (pos_spatial, pos_temporal, neg_spatial, neg_temporal) in enumerate( + zip(vis_embed_spatial, vis_embed_temporal, S, neg_vis_embed_temporal)): + if stm_labels[i] == 1: + vis_embed_spatial_mixed.append(pos_spatial) + vis_embed_temporal_mixed.append(pos_temporal) + else: + # 50% negative spatial / 50% negative temporal + if torch.rand(1).item() < 0.5: + vis_embed_spatial_mixed.append(pos_spatial) + vis_embed_temporal_mixed.append(neg_temporal) + else: + vis_embed_spatial_mixed.append(neg_spatial) + vis_embed_temporal_mixed.append(pos_temporal) + + vis_embed_spatial_mixed = torch.stack(vis_embed_spatial_mixed, dim=0) + vis_embed_temporal_mixed = torch.stack(vis_embed_temporal_mixed, dim=0) + + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + + moe_outputs = self.shared_forward( + vis_embed_spatial_mixed, vis_spatial_mask, vis_embed_temporal_mixed, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) + + stm_logits = self.vcm_head(moe_outputs['cls_feats']) + loss_stm = F.cross_entropy(stm_logits, stm_labels) + return loss_stm + + def mlm_iteration(self, vis, cap, is_vid, device): + if self.config.use_sep_spatial_temp_experts: + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) + else: + vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + cap_ids = cap_ids.tolist() + + # NOTE We make sure to mask some tokens here to avoid nan loss later + mlm_output = self.mlm_collactor(cap_ids) + cap_ids = mlm_output['input_ids'].to(device) + labels_cap = mlm_output['labels'].to(device) + + if self.config.use_sep_spatial_temp_experts: + moe_outputs = self.shared_forward( + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) + else: + moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts( + vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device) + + mlm_logits = self.lm_head(moe_outputs['cap_feats']) + loss_mlm = F.cross_entropy(mlm_logits.view(-1, mlm_logits.size(-1)), labels_cap.view(-1)) + return loss_mlm + + def vcc_iteration(self, vis, cap, is_vid, device): + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + + if self.config.use_sep_spatial_temp_experts: + moe_outputs = self.shared_forward( + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) + vis_feats = moe_outputs['spatial_feats'] + if is_vid: + vis_feats = torch.cat([moe_outputs['spatial_feats'], moe_outputs['temporal_feats']], dim=1) + else: + vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + moe_outputs = self.shared_forward_no_sep_spatial_temporal_experts( + vis_embed, vis_mask, cap_ids, cap_mask, is_vid, device) + vis_feats = moe_outputs['vis_feats'] + + cap_feats = F.normalize(self.cap_proj(moe_outputs['cls_feats']), dim=-1) + vis_feats = F.normalize(self.vision_proj(vis_feats), dim=-1) + + vis_feats_all = concat_all_gather(vis_feats) + cap_feats_all = concat_all_gather(cap_feats) + + sim_v2c = torch.matmul( + vis_feats.unsqueeze(1), cap_feats_all.unsqueeze(-1) + ).squeeze() + + sim_v2c, _ = sim_v2c.max(-1) + sim_v2c = sim_v2c / self.temp + + sim_c2v = torch.matmul( + cap_feats.unsqueeze(1).unsqueeze(1), vis_feats_all.permute(0, 2, 1) + ).squeeze() + + sim_c2v, _ = sim_c2v.max(-1) + sim_c2v = sim_c2v / self.temp + + rank = dist.get_rank() if self.config['distributed'] else 0 + + bs = vis_feats.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + device + ) + loss_vcc = ( + F.cross_entropy(sim_v2c, targets, label_smoothing=0.1) + + F.cross_entropy(sim_c2v, targets, label_smoothing=0.1) + ) / 2 + return loss_vcc + + def stc_iteration(self, vis, cap, is_vid, device): + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + moe_outputs = self.shared_forward( + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask, cap_ids, cap_mask, is_vid, device) + + spatial_feats = F.normalize(self.spatial_proj(moe_outputs['spatial_feats']), dim=-1) + temporal_feats = F.normalize(self.temp_proj(moe_outputs['temporal_feats']), dim=-1) + + spatial_feats_all = concat_all_gather(spatial_feats) + temporal_feats_all = concat_all_gather(temporal_feats) + + sim_s2t = torch.matmul( + spatial_feats.unsqueeze(1), temporal_feats_all + ) + + sim_s2t, _ = sim_s2t.max(-1) + sim_s2t, _ = sim_s2t.max(-1) + sim_s2t = sim_s2t / self.temp + + sim_t2s = torch.matmul( + temporal_feats.unsqueeze(1), spatial_feats_all + ) + + sim_t2s, _ = sim_t2s.max(-1) + sim_t2s, _ = sim_t2s.max(-1) + sim_t2s = sim_t2s / self.temp + + rank = dist.get_rank() if self.config['distributed'] else 0 + bs = vis.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + device + ) + loss_stc = ( + F.cross_entropy(sim_s2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2s, targets, label_smoothing=0.1) + ) / 2 + return loss_stc + + + def forward(self, vis, cap, neg_vis, media_type): + device = vis.device + is_vid = media_type == 'webvid' + loss_stc = torch.tensor(0).to(device) + loss_stm = torch.tensor(0).to(device) + loss_vcc = torch.tensor(0).to(device) + loss_vcm = torch.tensor(0).to(device) + loss_mlm = torch.tensor(0).to(device) + + if self.config.loss_dict['vcm'] != 0: + loss_vcm = self.vcm_iteration(vis, cap, neg_vis, is_vid, device) + + if self.config.loss_dict['vcc'] != 0: + loss_vcc = self.vcc_iteration(vis, cap, is_vid, device) + + if self.config.loss_dict['stm'] != 0 and is_vid: + loss_stm = self.stm_iteration(vis, cap, neg_vis, is_vid, device) + + if self.config.loss_dict['stc'] != 0 and is_vid: + loss_stc = self.stc_iteration(vis, cap, is_vid, device) + + if self.config.loss_dict['mlm'] != 0: + loss_mlm = self.mlm_iteration(vis, cap, is_vid, device) + + return dict( + loss_stc = loss_stc * self.config.loss_dict['stc'], + loss_stm = loss_stm * self.config.loss_dict['stm'], + loss_vcc = loss_vcc * self.config.loss_dict['vcc'], + loss_vcm = loss_vcm * self.config.loss_dict['vcm'], + loss_mlm = loss_mlm * self.config.loss_dict['mlm'], + ) + + def forward__(self, vis, cap, neg_vis, media_type): + + device = vis.device + self.vcm_matching(vis, cap, neg_vis, media_type, device) + self.shared_forward(vis, cap, media_type, device) + + + # First init all losses to zeros + loss_stc = torch.tensor(0).to(device) + loss_stm = torch.tensor(0).to(device) + loss_vcc = torch.tensor(0).to(device) + loss_vcm = torch.tensor(0).to(device) + loss_mlm = torch.tensor(0).to(device) + + batch_size = len(cap) + # First get the visual features depending on the media type + vis_embed = self.encode_vis(vis) + neg_vis_embed = self.encode_vis(neg_vis) + + embed_dim = vis_embed.size(-1) + num_frames = vis.size(1) + # reshape the video features + vis_embed = vis_embed.view(batch_size, num_frames, -1, embed_dim) + neg_vis_embed = neg_vis_embed.view(batch_size, num_frames, -1, embed_dim) + + # Perfrom spatial temporal attention and reshape + vis_embed_spatial = self.spatial_att(vis_embed) + # vis_embed_spatial = vis_embed_spatial.view(batch_size, -1, embed_dim) + + neg_vis_embed_spatial = self.spatial_att(neg_vis_embed) + # neg_vis_embed_spatial = neg_vis_embed_spatial.view(batch_size, -1, embed_dim) + + if media_type == 'webvid': + vis_embed_temporal = self.temporal_att(vis_embed) + # vis_embed_temporal = vis_embed_temporal.view(batch_size, -1, embed_dim) + + neg_vis_embed_temporal = self.temporal_att(neg_vis_embed) + # neg_vis_embed_temporal = neg_vis_embed_temporal.view(batch_size, -1, embed_dim) + + spatial_feat_len = vis_embed_spatial.size(1) + + # construct the global input tensor --> use place holder for vis features + cap_ids, cap_attention_mask = self.tokenize_text(cap, device, max_len=self.config.max_cap_len) + input_ids, input_mask, special_toks_indices = self.construct_global_input(cap_ids, cap_attention_mask, spatial_feat_len, media_type, device) + + input_embeds = self.embed(input_ids) + + if media_type == 'webvid': + input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_spatial + input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_temporal + + elif media_type == 'cc3m': + input_embeds[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = vis_embed_spatial + + # LLM --> MoEs + input_embeds = self.moe_llm_bottleneck(input_embeds) + input_embeds_orig = input_embeds.clone() + + neg_vis_embed_spatial = self.moe_llm_bottleneck(neg_vis_embed_spatial) + + if media_type == 'webvid': + neg_vis_embed_temporal = self.moe_llm_bottleneck(neg_vis_embed_temporal) + + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + + input_embeds = moe_layer(input_embeds, special_toks_indices, expert_flag, mask=input_mask) + + #TODO normalize the output () !!!!!! + + #-------------------- Contrastive losses --------------------# + cap_proj_feats = F.normalize(self.cap_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) + vis_proj_feats = F.normalize(self.vision_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) + if media_type == 'webvid': + spatial_proj_feats = F.normalize(self.spatial_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) + temp_proj_feats = F.normalize(self.temp_proj(input_embeds[:, special_toks_indices[''], :]), dim=-1) # (bs*gpus, H) + + if self.config.loss_dict['vcc'] != 0: + vis_proj_feats_all = concat_all_gather(vis_proj_feats) # (bs*gpus, H) + cap_proj_feats_all = concat_all_gather(cap_proj_feats) # (bs*gpus, H) + + loss_vcc, _, _ = self.compute_contrastive_loss(vis_proj_feats, cap_proj_feats_all, cap_proj_feats, vis_proj_feats_all) + + # 1- Spatial-Temporal + if media_type == 'webvid': + if self.config.loss_dict['stc'] != 0: + spatial_proj_feats_all = concat_all_gather(spatial_proj_feats) # (bs*gpus, H) + temp_proj_feats_all = concat_all_gather(temp_proj_feats) # (bs*gpus, H) + loss_stc, _, _ = self.compute_contrastive_loss(temp_proj_feats, spatial_proj_feats_all, spatial_proj_feats, temp_proj_feats_all) + + + #-------------------- Matching losses --------------------# + if self.config.loss_dict['vcm'] != 0: + # Negative caption with positive visual + neg_cap_ids, neg_cap_attention_mask, = self.tokenize_text(neg_cap, device, max_len=self.config.max_cap_len) + neg_cap_embed = self.moe_llm_bottleneck(self.embed(neg_cap_ids)) + input_embeds_neg_cap = input_embeds_orig.clone().detach() + input_embeds_neg_cap[:, special_toks_indices[''] + 1:special_toks_indices['']] = neg_cap_embed + input_mask_neg_cap = input_mask.clone().detach() + input_mask_neg_cap[:, special_toks_indices[''] + 1:special_toks_indices['']] = neg_cap_attention_mask + + # Negative visual with positive caption + input_embeds_neg_vis = input_embeds_orig.clone().detach() + input_mask_neg_vis = input_mask.clone().detach() + + # neg_vis_embed = self.encode_vis(neg_vis) + + # # reshape video features + # neg_vis_embed = neg_vis_embed.reshape(batch_size, num_frames, -1, embed_dim) + + # # Perfrom spatial temporal attention and reshape + # neg_vis_embed_spatial = self.spatial_att(neg_vis_embed) + # neg_vis_embed_spatial = neg_vis_embed_spatial.reshape(batch_size, -1, embed_dim) + if media_type == 'webvid': + # neg_vis_embed_temporal = self.temporal_att(neg_vis_embed) + # neg_vis_embed_temporal = neg_vis_embed_temporal.reshape(batch_size, -1, embed_dim) + + input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial + input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_temporal + + elif media_type == 'cc3m': + # neg_vis_embed_spatial = self.moe_llm_bottleneck(neg_vis_embed_spatial) + input_embeds_neg_vis[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial + + # Construct the input of VCM + final_input_embeds_vcm = torch.cat([input_embeds_orig, input_embeds_neg_cap, input_embeds_neg_vis], dim=0) + final_input_mask_vcm = torch.cat([input_mask, input_mask_neg_cap, input_mask_neg_vis], dim=0) + + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + final_input_embeds_vcm = moe_layer(final_input_embeds_vcm, special_toks_indices, expert_flag, mask=final_input_mask_vcm) + + pooled_caption = self.caption_pooler(final_input_embeds_vcm, special_toks_indices['']) + pooled_vis = self.vis_pooler(final_input_embeds_vcm, special_toks_indices['']) + + vcm_feats = torch.mul(pooled_caption, pooled_vis) + vcm_logits = self.vcm_head(vcm_feats) + vcm_labels = torch.cat( + [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], + dim=0, + ).to(device) + + # random permutation of the logits and labels --> make the task not trivial to learn + # perm_idx = torch.randperm(vcm_logits.size(0), device=device) + # perm_idx_extended = perm_idx.unsqueeze(-1).repeat(1, vcm_logits.size(-1)) + + # # Shuffle + # vcm_logits = vcm_logits.scatter(0, perm_idx_extended, vcm_logits) + # vcm_labels = vcm_labels.scatter(0, perm_idx, vcm_labels) + + # class_weight = torch.FloatTensor([1.0, 1.0/3]).to(device) + + loss_vcm = F.cross_entropy(vcm_logits, vcm_labels) # , weight=class_weight) + + if media_type == 'webvid': + if self.config.loss_dict['stm'] != 0: + # Negative spatial with positive temporal + input_embeds_neg_spatial = input_embeds_orig.clone().detach() + input_mask_neg_spatial = input_mask.clone().detach() + input_embeds_neg_spatial[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_spatial + + # Positive spatial with negative temporal + input_embeds_neg_temporal = input_embeds_orig.clone().detach() + input_mask_neg_temporal = input_mask.clone().detach() + input_embeds_neg_temporal[:, special_toks_indices[''] + 1: special_toks_indices[''], :] = neg_vis_embed_temporal + + # Construct the input of STM + final_input_embeds_stm = torch.cat([input_embeds_orig, input_embeds_neg_spatial, input_embeds_neg_temporal], dim=0) + final_input_mask_stm = torch.cat([input_mask, input_mask_neg_spatial, input_mask_neg_temporal], dim=0) + + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + final_input_embeds_stm = moe_layer(final_input_embeds_stm, special_toks_indices, expert_flag, mask=final_input_mask_stm) + + pooled_spatial = self.spatial_pooler(final_input_embeds_stm, special_toks_indices['']) + pooled_temporal = self.temporal_pooler(final_input_embeds_stm, special_toks_indices['']) + + stm_feats = torch.mul(pooled_spatial, pooled_temporal) + stm_logits = self.stm_head(stm_feats) + stm_labels = torch.cat( + [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], + dim=0, + ).to(device) + + # random permutation of the logits and labels --> make the task not trivial to learn + # perm_idx = torch.randperm(stm_logits.size(0), device=device) + # perm_idx_extended = perm_idx.unsqueeze(-1).repeat(1, stm_logits.size(-1)) + + # # Shuffle + # stm_logits = stm_logits.scatter(0, perm_idx_extended, stm_logits) + # stm_labels = stm_labels.scatter(0, perm_idx, stm_labels) + + # class_weight = torch.FloatTensor([1.0, 1.0/3]).to(device) + loss_stm = F.cross_entropy(stm_logits, stm_labels) # , weight=class_weight) + + if self.config.loss_dict['mlm'] != 0: + masked_cap_ids, labels = self.mlm(cap_ids.clone()) + masked_cap_embeds = self.moe_llm_bottleneck(self.embed(masked_cap_ids)) + # inject the masked embeddings instead of the original ones + # input_embeds_mlm[:, special_toks_indices['']+1 : special_toks_indices[''], :] = masked_cap_embeds + + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + masked_cap_embeds = moe_layer(masked_cap_embeds, special_toks_indices, expert_flag, mask=cap_attention_mask, only_text=True) + + # extract the caption last hidden states + # masked_cap_embeds_last = input_embeds_mlm[:, special_toks_indices['']+1 : special_toks_indices[''], :] + lm_logits = self.lm_head(masked_cap_embeds) + loss_mlm = F.cross_entropy( + lm_logits.view(-1, len(self.tokenizer)), + labels.view(-1), + ignore_index=self.mlm.padding_token + ) + + return dict( + loss_stc = loss_stc * self.config.loss_dict['stc'], + loss_stm = loss_stm * self.config.loss_dict['stm'], + loss_vcc = loss_vcc * self.config.loss_dict['vcc'], + loss_vcm = loss_vcm * self.config.loss_dict['vcm'], + loss_mlm = loss_mlm * self.config.loss_dict['mlm'], + ) + + + def get_vis_enc_for_eval(self, vis, media_type): + # First get the visual features depending on the media type + vis_spatial_embed, vis_temporal_embed = self.encode_vis(vis, media_type) + + # Expand the query tokens + spatial_query_embeds = self.spatial_query_embeds.expand(vis_spatial_embed.size(0), -1, -1) + + # Run the spatial expert + spatial_query_embeds, pooled_spatial_query_embeds = self.encode_queries( + spatial_query_embeds, vis_spatial_embed, vis_mode='spatial') + + temporal_query_embeds = self.spatial_query_embeds.expand(vis_temporal_embed.size(0), -1, -1) + temporal_query_embeds, pooled_temporal_query_embeds = self.encode_queries( + temporal_query_embeds, vis_temporal_embed, vis_mode='temporal') + + vis_pooled = torch.cat((pooled_spatial_query_embeds, pooled_temporal_query_embeds), dim=1) + vis_embeds = torch.cat((spatial_query_embeds, temporal_query_embeds), dim=1) + + return vis_embeds, vis_pooled + + def get_expert_encoder(self, expert): + """get text encoder, used for text and cross-modal encoding""" + encoder = None + if expert == 'cap': + encoder = self.cap_expert + if expert == 'spatial': + encoder = self.spatial_expert + if expert == 'temporal': + encoder = self.temporal_expert + if expert == 'sap_att_grounding': + encoder = self.spa_temp_grounding_expert + if expert == 'vis_cap_grounding': + encoder = self.vis_cap_grounding_expert + assert encoder is not None + return encoder.bert if hasattr(encoder, "bert") else encoder + + + +class V2Dial(V2DialAbstract): + def __init__(self, config): + super(V2Dial, self).__init__() + self.config = config + + ################## 1. Select Tokenizer -- We use BERT tokenizer ################## + bert_config = BertConfig.from_pretrained('bert-{}-uncased'.format(config.expert_size)) + tokenizer = AutoTokenizer.from_pretrained('bert-{}-uncased'.format(config.expert_size)) + + text_embedding = BertEmbeddings(bert_config) + text_embedding.apply(self.init_weights) + + token_type_embedding = nn.Embedding(3, bert_config.hidden_size) # Number of modalities (temp/spa/cap/hist-ques-ans) + token_type_embedding.apply(self.init_weights) + + ################## 1. Select LLM -- We use BERT tokenizer ################## + if config.llm_family == 'llama': + logging.info('[INFO] LLM: LLAMA v2') + llm_model = LlamaForCausalLM + + elif config.llm_family == 'mistral': + logging.info('[INFO] LLM: Mistral') + llm_model = MistralForCausalLM + + elif config.llm_family == 'flan_t5': + logging.info('[INFO] LLM: Flan T5') + llm_model = T5ForConditionalGeneration + + elif config.llm_family == 'bart': + logging.info('[INFO] LLM: BART') + llm_model = BartForConditionalGeneration + else: + raise ValueError + + + llm_tokenizer = AutoTokenizer.from_pretrained( + config.llm_name, + use_fast=False, + token='your_token' + ) + # set the padding token to eos token for llama + if config.llm_family == 'llama': + llm_tokenizer.pad_token = llm_tokenizer.eos_token + + #________________________________ LLM Quantization ________________________________# + if config.llm_family in ['mistral', 'llama']: + dtype=None + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type='nf4', + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16 + ) + else: + if config.fp16: + dtype = torch.float16 + if config.llm_family == 'flan_t5': + dtype = torch.bfloat16 + else: + dtype = torch.float32 + quantization_config = None + + # llm_model.generate() + llm = llm_model.from_pretrained( + config.llm_name, + token='your_token', + torch_dtype=dtype, + quantization_config=quantization_config + ) + + if config.llm_family == 'llama': + llm_embed = llm.model.embed_tokens + elif config.llm_family == 'flan_t5': + llm_embed = llm.shared + elif config.llm_family == 'mistral': + llm_embed = llm.model.embed_tokens + elif config.llm_family == 'bart': + llm_embed = llm.model.shared + else: + raise ValueError + + # llm.resize_token_embeddings(len(self.tokenizer)) + if quantization_config is not None: + # Gradient checkpointing is not compatible with DDP!! + llm = prepare_model_for_kbit_training(llm, use_gradient_checkpointing=True) + + + if config.freeze_llm: + for _, param in llm.named_parameters(): + param.requires_grad = False + logging.info('[INFO] LLM frozen') + else: + if config.use_lora_llm: + # load the lora config + with open(config.lora_config, 'r') as f: + lora_config = json.load(f) + + if config.llm_family in ['llama', 'mistral']: + lora_config['target_modules'] = ['q_proj', 'v_proj'] + + elif config.llm_family in ['flan_t5']: + lora_config['target_modules'] = ['q', 'v'] + + lora_config = LoraConfig(**lora_config) + llm = get_peft_model(llm, lora_config) + + logging.info('[INFO] LLM hot with lora') + else: + logging.info('[INFO] LLM hot') + + logging.info('[INFO] LLM successfully loaded') + + for _, param in llm_embed.named_parameters(): + param.data = param.data.float() + param.requires_grad = True + + llm_to_moe = nn.Linear(llm.config.hidden_size, bert_config.hidden_size) + llm_to_moe.apply(self.init_weights) + + moe_to_llm = nn.Linear(bert_config.hidden_size, llm.config.hidden_size) + moe_to_llm.apply(self.init_weights) + + ################## 2. Select the backbone ViT ################## + logging.info('[INFO] Loading ViT in progress') + if config.freeze_vit: + # vit_precision = 'fp16' if config.fp16 else 'fp32' + logging.info(f'[INFO] ViT precision: {config.vit_precision}') + visual_encoder, ln_vision = self.init_vision_encoder( + config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, precision=config.vit_precision + ) + for name, param in visual_encoder.named_parameters(): + param.requires_grad = False + visual_encoder = visual_encoder.eval() + visual_encoder.train = disabled_train + for name, param in ln_vision.named_parameters(): + param.requires_grad = False + ln_vision = ln_vision.eval() + ln_vision.train = disabled_train + logging.info('[INFO] ViT frozen') + + else: + vit_precision = 'fp32' + visual_encoder, ln_vision = self.init_vision_encoder( + config.vit_model, config.image_res, drop_path_rate=0, use_grad_checkpoint=False, vit_precision=vit_precision + ) + logging.info('[INFO] ViT hot') + logging.info('[INFO] ViT successfully loaded') + + ################## 3. Define the ViT-Expert communication Interface ################## + self.system_prompt = False + self.vit_token_pooling = config.vit_token_pooling + if self.vit_token_pooling: + vit_proj = nn.Linear( + 1408*4, bert_config.hidden_size + ) + else: + vit_proj = nn.Linear( + 1408, bert_config.hidden_size + ) + vit_proj.apply(self.init_weights) + + spatial_att = SpatialAttention(input_dim=bert_config.hidden_size) + temporal_att = TemporalAttention(input_dim=bert_config.hidden_size) + + spatial_att.apply(self.init_weights) + temporal_att.apply(self.init_weights) + + ################## 4. Define the Expert layers ################## + moe_layers = None + moe_norm = None + if config.use_moes: + moe_layers = [] + + for moe_layer_idx in range(config.num_moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + moe_layer = MoELayer( + bert_config.hidden_size, + bert_config.num_attention_heads, + expert_flag, + has_hist=True, + use_sep_spatial_temp_experts=config.use_sep_spatial_temp_experts + ) + + moe_layer.apply(self.init_weights) + moe_layers.append(moe_layer) + + logging.info(f'[INFO] {moe_layer_idx+1}/{config.num_moe_layers} MoE layers successfully loaded') + + moe_layers = nn.ModuleList(moe_layers) + moe_norm = nn.LayerNorm(bert_config.hidden_size) + + ################## 5. Define the projection layers for contrastive learning ################## + # temp_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + # spatial_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + # vision_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + # cap_proj = nn.Linear(bert_config.hidden_size, config.joint_dim) + + # temp_proj.apply(self.init_weights) + # spatial_proj.apply(self.init_weights) + # vision_proj.apply(self.init_weights) + # cap_proj.apply(self.init_weights) + + ################## 6. Define the pooler for matching loss ################## + # pooler = Pooler(bert_config.hidden_size) + # pooler.apply(self.init_weights) + + ################## 5. Attach the matching heads ################## + # stm_head = nn.Linear(bert_config.hidden_size, 2) + # vcm_head = nn.Linear(bert_config.hidden_size, 2) + # lm_head = nn.Linear(bert_config.hidden_size, len(tokenizer)) + + # stm_head.apply(self.init_weights) + # vcm_head.apply(self.init_weights) + # lm_head.apply(self.init_weights) + + temp = nn.Parameter(0.07 * torch.ones([])) + # temp = 0.07 + + # Attach the components to self + if self.config.embed_from_llm: + self.tokenizer = llm_tokenizer + self.text_embedding = llm_embed + else: + self.tokenizer = tokenizer + self.text_embedding = text_embedding + self.token_type_embedding = token_type_embedding + + self.llm = llm + self.llm_to_moe = llm_to_moe + self.moe_to_llm = moe_to_llm + self.visual_encoder = visual_encoder + self.ln_vision = ln_vision + self.vit_proj = vit_proj + self.moe_layers = moe_layers + self.moe_norm = moe_norm + self.spatial_att = spatial_att + self.temporal_att = temporal_att + # self.temp_proj = temp_proj + # self.spatial_proj = spatial_proj + # self.vision_proj = vision_proj + # self.cap_proj = cap_proj + # self.pooler = pooler + # self.stm_head = stm_head + # self.vcm_head = vcm_head + # self.lm_head = lm_head + self.temp = temp + + def construct_global_input(self, cap_ids, cap_attention_mask, hist_ids, hist_attention_mask, vid_feat_len, device): + # for video: [spatial_feats][temp_feats][cap_feats][hist_feats] + + batch_size = cap_ids.size(0) + special_toks_indices = { + '': 0, + '': 1, + '': 2, + } + + ids = [self.added_vocab['']] + [self.added_vocab['']] + [self.added_vocab['']] + ids += vid_feat_len * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + ids += vid_feat_len * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + ids += cap_ids.size(1) * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + ids += hist_ids.size(1) * [self.added_vocab['']] + + ids += [self.added_vocab['']] + special_toks_indices[''] = len(ids) - 1 + total_len = len(ids) + + ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + + ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_ids + ids[:, special_toks_indices[''] + 1: special_toks_indices['']] = hist_ids + + + mask = torch.ones((batch_size, total_len), device=device) + mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = cap_attention_mask + mask[:, special_toks_indices[''] + 1: special_toks_indices['']] = hist_attention_mask + + return ids, mask, special_toks_indices + + def construct_reg_labels(self, regress_ids, start_regress_idx, full_embeds, device): + + full_labels = torch.LongTensor(full_embeds.size(0), full_embeds.size(1)).fill_(-100).to(device) + + for i in range(regress_ids.size(0)): + + full_labels[i, start_regress_idx[i]: start_regress_idx[i] + regress_ids[i].size(-1)] = regress_ids[i] + # Add to the labels -- just before the response starts + full_labels[i, start_regress_idx[i] - 1] = self.tokenizer.eos_token_id + + # labels = regress_ids.masked_fill( + # regress_ids == self.tokenizer.pad_token_id, -100 + # ).to(device) + + # eos_from_cond = torch.LongTensor(labels.size(0), 1).fill_(self.tokenizer.eos_token_id).to(device) + # labels = torch.concat([eos_from_cond, labels], dim=1) + + # full_labels = torch.LongTensor(labels.size(0), full_len).fill_(-100).to(device) + + # full_labels[:, len_cond-1:] = labels + + return full_labels + + def rearrange_llm_input_decoder_only(self, input_embeds, output_emebds, input_mask, cap_mask, hist_mask, output_mask, spatial_feat_len): + ''' + Push all pads to the right + ''' + # full_embeds = [...][...][...][pad][...][pad][ans ...][pad] + # ------------> [...][...][...][...][ans ...][-----pad-----] + + init_len = input_embeds.size(1) + output_emebds.size(1) + + # First, we compute the initial offset of the visual features + offset = 3 + spatial_feat_len + 1 + spatial_feat_len # --> input_embeds[offset] = h_ + + offset_embeds = input_embeds[:, :offset, :] + offset_mask = input_mask[:, :offset] + + rest_input_embdes = input_embeds[:, offset:, :] + rest_input_mask = input_mask[:, offset:] + + start_output_idx = [] + full_embeds = [] + full_masks = [] + + for i in range(input_embeds.size(0)): + output_emebd_i = output_emebds[i] + output_mask_i = output_mask[i] + + cap_mask_i = cap_mask[i] + len_cap_i = cap_mask_i.sum() + end_cap_i = len_cap_i + 1 # +1 for the token + + cap_embdes_i_to_keep = rest_input_embdes[i, :end_cap_i, :] + cap_mask_i_to_keep = rest_input_mask[i, :end_cap_i,] + cap_embeds_i_to_push = rest_input_embdes[i, end_cap_i:cap_mask_i.size(-1) + 1, :] # +1 for the token + cap_mask_i_to_push = rest_input_mask[i, end_cap_i: cap_mask_i.size(-1) + 1] # +1 for the token + + hist_mask_i = hist_mask[i] + len_hist_i = hist_mask_i.sum() + start_hist_i = cap_mask_i.size(-1) + 1 + end_hist_i = start_hist_i + len_hist_i + 1 # +1 for token + + # fianl token to keep is which is the last in input_embdes/rest_input_embdes + final_tok_embedding_i = rest_input_embdes[i, -1, :].unsqueeze(0) + final_tok_mask_i = rest_input_mask[i, -1].unsqueeze(0) + + hist_embdes_i_to_keep = rest_input_embdes[i, start_hist_i:end_hist_i, :] + hist_mask_i_to_keep = rest_input_mask[i, start_hist_i:end_hist_i] + + # these two do not consider the last token --> we don't need to extra remove it from them + hist_embdes_i_to_push = rest_input_embdes[i, end_hist_i: cap_mask_i.size(-1) + 1 + hist_mask_i.size(-1) + 1, :] + hist_mask_i_to_push = rest_input_mask[i, end_hist_i: cap_mask_i.size(-1) + 1 + hist_mask_i.size(-1) + 1] + + full_embed_i = torch.cat( + [cap_embdes_i_to_keep, hist_embdes_i_to_keep, final_tok_embedding_i, output_emebd_i, cap_embeds_i_to_push, hist_embdes_i_to_push], + dim=0 + ) + + full_mask_i = torch.cat( + [cap_mask_i_to_keep, hist_mask_i_to_keep, final_tok_mask_i, output_mask_i, cap_mask_i_to_push, hist_mask_i_to_push], + dim=0 + ) + + start_output_idx.append(offset + cap_embdes_i_to_keep.size(0) + hist_embdes_i_to_keep.size(0) + 1 - 1) + + full_embeds.append(full_embed_i) + full_masks.append(full_mask_i) + + # Now stack to get the batch + full_embeds = torch.stack(full_embeds, dim=0) + full_masks = torch.stack(full_masks, dim=0) + + # Add the offset visual features + full_embeds = torch.cat([offset_embeds, full_embeds], dim=1) + full_masks = torch.cat([offset_mask, full_masks], dim=1) + + final_len = full_embeds.size(1) + + # Sanity check + assert init_len == final_len, 'The reconstructed embeds have length ({}) which is not the same as the length of initial embeds ({})'.format( + final_len, init_len + ) + + return full_embeds, full_masks, start_output_idx + + def pad_to_right_enc_dec(self, cap_embeds, cap_masks, hist_embeds, hist_masks, device): + """ + pushes all in-between pad tokens to the right + """ + res_embeds = [] + res_mask = [] + for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks): + len_cap = sum(cap_mask) + len_hist = sum(hist_mask) + + batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], cap_embed[len_cap:], hist_embed[len_hist:]], dim=0) + batch_mask = torch.zeros(batch_embed.size(0)).long().to(device) + batch_mask[:len_cap+len_hist] = 1 + + res_embeds.append(batch_embed) + res_mask.append(batch_mask) + + res_embeds = torch.stack(res_embeds, dim=0) + res_mask = torch.stack(res_mask, dim=0) + + return res_embeds, res_mask + + def pad_to_right_dec_only(self, cap_embeds, cap_masks, hist_embeds, hist_masks, regress_embeds, regress_masks, device): + """ + pushes all in-between pad tokens to the right + """ + res_embeds = [] + res_mask = [] + regress_limits_txt_input = [] + for cap_embed, cap_mask, hist_embed, hist_mask, regress_emebd, regress_mask in zip( + cap_embeds, cap_masks, hist_embeds, hist_masks, regress_embeds, regress_masks): + + len_cap = sum(cap_mask) + len_hist = sum(hist_mask) + len_ans = sum(regress_mask) + regress_limits_txt_input.append((len_cap+len_hist, len_cap+len_hist+len_ans)) + + batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], regress_emebd, cap_embed[len_cap:], hist_embed[len_hist:]], dim=0) + batch_mask = torch.zeros(batch_embed.size(0)).long().to(device) + batch_mask[:len_cap+len_hist+len_ans] = 1 + + res_embeds.append(batch_embed) + res_mask.append(batch_mask) + + res_embeds = torch.stack(res_embeds, dim=0) + res_mask = torch.stack(res_mask, dim=0) + + return res_embeds, res_mask, regress_limits_txt_input + + def pad_to_right_dec_only_gen_mode(self, cap_embeds, cap_masks, hist_embeds, hist_masks, device): + """ + pushes all in-between pad tokens to the right + """ + res_embeds = [] + res_mask = [] + for cap_embed, cap_mask, hist_embed, hist_mask in zip(cap_embeds, cap_masks, hist_embeds, hist_masks): + + len_cap = sum(cap_mask) + len_hist = sum(hist_mask) + + batch_embed = torch.cat([cap_embed[:len_cap], hist_embed[:len_hist], cap_embed[len_cap:], hist_embed[len_hist:]], dim=0) + batch_mask = torch.zeros(batch_embed.size(0)).long().to(device) + batch_mask[:len_cap+len_hist] = 1 + + res_embeds.append(batch_embed) + res_mask.append(batch_mask) + + res_embeds = torch.stack(res_embeds, dim=0) + res_mask = torch.stack(res_mask, dim=0) + + return res_embeds, res_mask + + def encode_vis_with_seq_spa_temp_att(self, image, device, is_vid=True): + num_frames = image.size(1) + bs_pre_reshape = image.size(0) + if len(image.shape) > 4: + image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + # with self.maybe_autocast(): # inherited from Blip2Base + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + + bs, pn, hs = image_embeds.shape + if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) + image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + vis_embed = self.vit_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096) + + # reshape the video features + vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) + size_orig = vis_embed.size() + + # Perfrom spatial temporal attention + vis_embed = self.spatial_att(vis_embed) + if is_vid: + vis_embed = vis_embed.view(size_orig) + vis_embed = self.temporal_att(vis_embed) + + vis_feat_len = vis_embed.size(1) + + # vis_embed = vis_embed + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) + vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + return vis_embed, vis_mask + + def moe_forward_no_sep_spatial_temporal( + self, + vis, vis_mask, + cap_ids, cap_mask, hist_ids, hist_mask, + is_vid, device): + + # is_vid = media_type == 'webvid' + # batch_size = len(cap) + vis_feat_len = vis.size(1) + input_embeds = [] + input_masks = [] + + input_embeds.append(vis) + input_masks.append(vis_mask) + + # if is_vid: + # input_embeds.append(vis_temporal) + # input_masks.append(vis_temporal_mask) + + if self.config.embed_from_llm: + cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids)) + else: + cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) + + cap_feat_len = cap_embeds.size(1) + + input_embeds.append(cap_embeds) + input_masks.append(cap_mask) + + if self.config.embed_from_llm: + hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids)) + else: + hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2)) + + hist_feat_len = hist_embeds.size(1) + + input_embeds.append(hist_embeds) + input_masks.append(hist_mask) + + input_embeds = torch.cat(input_embeds, dim=1) + input_masks = torch.cat(input_masks, dim=1) + + # expand the mask + input_masks = self.get_extended_attention_mask(attention_mask=input_masks) + + # MoEs feed-forward + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + + input_embeds = moe_layer(input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len, is_vid=is_vid, mask=input_masks) + + #TODO normalize the output () !!!!!! + input_embeds = self.moe_norm(input_embeds) + + # return the features + vis_embeds = input_embeds[:, :vis_feat_len] + # temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None + cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len] + hist_embeds = input_embeds[:, -hist_feat_len:] + # cls_feats = self.pooler(cap_feats) + + moe_outputs = { + 'vis_embeds': vis_embeds, + # 'temporal_embeds': temporal_embeds, + 'cap_embeds': cap_embeds, + 'hist_embeds': hist_embeds, + # 'cls_feats': cls_feats, + # 'last_hidden': input_embeds + } + + return moe_outputs + + def moe_forward( + self, + vis_spatial, vis_spatial_mask, vis_temporal, vis_temporal_mask, + cap_ids, cap_mask, hist_ids, hist_mask, + is_vid, device): + + # is_vid = media_type == 'webvid' + # batch_size = len(cap) + vis_feat_len = vis_spatial.size(1) + input_embeds = [] + input_masks = [] + + input_embeds.append(vis_spatial) + input_masks.append(vis_spatial_mask) + + if is_vid: + input_embeds.append(vis_temporal) + input_masks.append(vis_temporal_mask) + + if self.config.embed_from_llm: + cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids)) + else: + cap_embeds = self.text_embedding(cap_ids) + self.token_type_embedding(torch.ones_like(cap_ids).long().fill_(2)) + + cap_feat_len = cap_embeds.size(1) + + input_embeds.append(cap_embeds) + input_masks.append(cap_mask) + + if self.config.embed_from_llm: + hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids)) + else: + hist_embeds = self.text_embedding(hist_ids) + self.token_type_embedding(torch.ones_like(hist_ids).long().fill_(2)) + + hist_feat_len = hist_embeds.size(1) + + input_embeds.append(hist_embeds) + input_masks.append(hist_mask) + + input_embeds = torch.cat(input_embeds, dim=1) + input_masks = torch.cat(input_masks, dim=1) + + # expand the mask + input_masks = self.get_extended_attention_mask(attention_mask=input_masks) + + # MoEs feed-forward + for moe_layer_idx, moe_layer in enumerate(self.moe_layers): + if moe_layer_idx < self.config.num_moe_modality_layers: + expert_flag = 'modalities' + else: + expert_flag = 'fusion' + + input_embeds = moe_layer( + input_embeds, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len, + is_vid=is_vid, + mask=input_masks, + expert_permutation=self.config.expert_permutation + ) + + #TODO normalize the output () !!!!!! + input_embeds = self.moe_norm(input_embeds) + + # return the features + spatial_embeds = input_embeds[:, :vis_feat_len] + temporal_embeds = input_embeds[:, vis_feat_len:2*vis_feat_len] if is_vid else None + cap_embeds = input_embeds[:, -(cap_feat_len + hist_feat_len): -hist_feat_len] + hist_embeds = input_embeds[:, -hist_feat_len:] + # cls_feats = self.pooler(cap_feats) + + moe_outputs = { + 'spatial_embeds': spatial_embeds, + 'temporal_embeds': temporal_embeds, + 'cap_embeds': cap_embeds, + 'hist_embeds': hist_embeds, + # 'cls_feats': cls_feats, + # 'last_hidden': input_embeds + } + + return moe_outputs + + def forward(self, vis, cap, hist, ans, media_type): + + device = vis.device + is_vid = media_type in ['webvid', 'champagne', 'avsd', 'nextqa'] + loss_stc = torch.tensor(0) + loss_stm = torch.tensor(0) + loss_vhc = torch.tensor(0) + loss_vhm = torch.tensor(0) + loss_gen = torch.tensor(0) + + # construct the global input tensor --> use place holder for vis features + cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=None) + hist_ids, hist_mask = self.tokenize_text(hist, device, max_len=None) + if self.config.use_moes: + # First get the visual features depending on the media type + if self.config.use_sep_spatial_temp_experts: + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = self.encode_vis(vis, device, is_vid=is_vid) + spatial_feat_len = vis_embed_spatial.size(1) + + else: + vis_embed, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + + if self.config.use_sep_spatial_temp_experts: + moe_outputs = self.moe_forward( + vis_embed_spatial, vis_spatial_mask, + vis_embed_temporal, vis_temporal_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + spatial_embeds = self.moe_to_llm(moe_outputs['spatial_embeds']) + temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + # cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds']) + # hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds']) + + else: + moe_outputs = self.moe_forward_no_sep_spatial_temporal( + vis_embed, vis_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + vis_embeds = self.moe_to_llm(moe_outputs['vis_embeds']) + # temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds']) + hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds']) + else: + cap_embeds = self.llm_to_moe(self.text_embedding(cap_ids)) + hist_embeds = self.llm_to_moe(self.text_embedding(hist_ids)) + + vis_embeds, vis_mask = self.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + ans = [a + self.tokenizer.eos_token for a in ans] + + if self.config.llm_family in ['llama', 'mistral']: + bos = torch.ones_like(cap_ids[:, :1]) * self.tokenizer.bos_token_id + bos_embeds = self.text_embedding(bos) + bos_mask = cap_mask[:, :1] + + # add corresponding eos + + regress_ids, regress_mask = self.tokenize_text(ans, device, max_len=None) # pad the longest + + regress_embeds = self.text_embedding(regress_ids) + + inputs_embeds, attention_mask, regress_limits_txt_input = self.pad_to_right_dec_only(cap_embeds, cap_mask, hist_embeds, hist_mask, regress_embeds, regress_mask, device) + + if is_vid: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1) + + labels = torch.zeros(inputs_embeds.size()[:-1]).fill_(-100).long().to(device) + + for i in range(labels.size(0)): + start_regress = regress_limits_txt_input[i][0] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid) # offset (bos + spatial + temporal) + end_regress = regress_limits_txt_input[i][1] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid) # offset (bos + spatial + temporal) + + labels[i, start_regress:end_regress] = regress_ids[i, :regress_mask[i].sum()] + + + # get causal attention mask + + + # Compute the regression embeds + + # Now we need to right-pad the input to LLM (at least for llama) to avoid nan loss values + # This means, all pad tokens have to be placed to the right + # full_embeds = [...][...][...][pad][...][pad][ans ...][pad] + # ------------> [...][...][...][...][ans ...][-----pad-----] + + # full_embeds, full_masks, start_output_idx = self.rearrange_llm_input_dec_only(cond_embeds, regress_embeds, cond_mask, cap_mask, hist_mask, regress_mask, spatial_feat_len) + + # labels = self.construct_reg_labels(regress_ids, start_output_idx, full_embeds, device) + + lm_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + return_dict=True + ) + loss_gen = lm_outputs.loss + + # Encoder Decoder + else: + inputs_embeds, attention_mask = self.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + + # now merge the multi-modal inputs + if self.config.use_moes: + if self.config.use_sep_spatial_temp_experts: + if is_vid: + inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + + decoder_ids, decoder_mask = self.tokenize_text(ans, device, max_len=None) # pad the longest + + labels = decoder_ids.masked_fill(decoder_ids == self.tokenizer.pad_token_id, -100) + decoder_ids = self.shift_right(labels) + decoder_inputs_embeds = self.text_embedding(decoder_ids) + + lm_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_inputs_embeds=decoder_inputs_embeds, + decoder_attention_mask=decoder_mask, + labels=labels, + return_dict=True + ) + + loss_gen = lm_outputs.loss + + return dict( + loss_stc = loss_stc * self.config.loss_dict['stc'], + loss_stm = loss_stm * self.config.loss_dict['stm'], + loss_vhc = loss_vhc * self.config.loss_dict['vhc'], + loss_vhm = loss_vhm * self.config.loss_dict['vhm'], + loss_gen = loss_gen * self.config.loss_dict['gen'], + ) + + +class V2DialNoMoes(V2Dial): + def __init__(self, config): + super(V2DialNoMoes, self).__init__(config) + + def encode_vis(self, image, device, is_vid=True): + num_frames = image.size(1) + bs_pre_reshape = image.size(0) + if len(image.shape) > 4: + image = image.view(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224) + # with self.maybe_autocast(): # inherited from Blip2Base + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408) + image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408) + + bs, pn, hs = image_embeds.shape + if self.vit_token_pooling: # concat the each 4 tokens into one token (200,64,5632) + image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632) + + vis_embed = self.vit_proj(image_embeds) # project to LLM input size (200,64,5632) -> (200,64, d_hidden) + + # reshape the video features + vis_embed = vis_embed.view(bs_pre_reshape, num_frames, -1, vis_embed.size(-1)) + + + # Perfrom spatial temporal attention + if is_vid: + vis_embed = self.temporal_att(vis_embed) + if not self.config.embed_from_llm: + vis_embed_temporal = vis_embed_temporal + self.token_type_embedding(torch.ones(bs_pre_reshape, vis_feat_len).long().to(device)) + # vis_temporal_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + vis_embed = self.spatial_att(vis_embed) + vis_feat_len = vis_embed_spatial.size(1) + + if not self.config.embed_from_llm: + vis_embed_spatial = vis_embed_spatial + self.token_type_embedding(torch.zeros(bs_pre_reshape, vis_feat_len).long().to(device)) + vis_mask = torch.ones((bs_pre_reshape, vis_feat_len)).to(device) + + return vis_embed, vis_mask + + + def forward(self, vis, cap, hist, ans, media_type): + + device = vis.device + is_vid = media_type in ['webvid', 'champagne', 'avsd', 'nextqa'] + loss_stc = torch.tensor(0) + loss_stm = torch.tensor(0) + loss_vhc = torch.tensor(0) + loss_vhm = torch.tensor(0) + loss_gen = torch.tensor(0) + + # First get the visual features depending on the media type + vis_embed, vis_mask = self.encode_vis(vis, device, is_vid=is_vid) + + # spatial_feat_len = vis_embed_spatial.size(1) + + # construct the global input tensor --> use place holder for vis features + # text = (c + h for c,h in zip(cap, hist)) + # cap_ids, cap_mask = self.tokenize_text(cap, device, max_len=None) + # hist_ids, hist_mask = self.tokenize_text(hist, device, max_len=None) + # text_ids, text_mask = self.tokenize_text(text, device, max_len=None) + + text_embeds = self.text_embedding(text_ids) + # moe_outputs = self.moe_forward( + # vis_embed_spatial, vis_spatial_mask, + # vis_embed_temporal, vis_temporal_mask, + # cap_ids, cap_mask, + # hist_ids, hist_mask, + # is_vid, device + # ) + # spatial_embeds = self.moe_to_llm(moe_outputs['spatial_embeds']) + # temporal_embeds = self.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + # cap_embeds = self.moe_to_llm(moe_outputs['cap_embeds']) + # hist_embeds = self.moe_to_llm(moe_outputs['hist_embeds']) + + ans = [a + self.tokenizer.eos_token for a in ans] + + if self.config.llm_family in ['llama', 'mistral']: + bos = torch.ones_like(cap_ids[:, :1]) * self.tokenizer.bos_token_id + bos_embeds = self.text_embedding(bos) + bos_mask = cap_mask[:, :1] + + # add corresponding eos + + regress_ids, regress_mask = self.tokenize_text(ans, device, max_len=None) # pad the longest + + regress_embeds = self.text_embedding(regress_ids) + + inputs_embeds, attention_mask, regress_limits_txt_input = self.pad_to_right_dec_only(cap_embeds, cap_mask, hist_embeds, hist_mask, regress_embeds, regress_mask, device) + + if is_vid: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1) + + labels = torch.zeros(inputs_embeds.size()[:-1]).fill_(-100).long().to(device) + + for i in range(labels.size(0)): + start_regress = regress_limits_txt_input[i][0] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid) # offset (bos + spatial + temporal) + end_regress = regress_limits_txt_input[i][1] + 1 + spatial_feat_len + spatial_feat_len * int(is_vid) # offset (bos + spatial + temporal) + + labels[i, start_regress:end_regress] = regress_ids[i, :regress_mask[i].sum()] + + + # get causal attention mask + + + # Compute the regression embeds + + # Now we need to right-pad the input to LLM (at least for llama) to avoid nan loss values + # This means, all pad tokens have to be placed to the right + # full_embeds = [...][...][...][pad][...][pad][ans ...][pad] + # ------------> [...][...][...][...][ans ...][-----pad-----] + + # full_embeds, full_masks, start_output_idx = self.rearrange_llm_input_dec_only(cond_embeds, regress_embeds, cond_mask, cap_mask, hist_mask, regress_mask, spatial_feat_len) + + # labels = self.construct_reg_labels(regress_ids, start_output_idx, full_embeds, device) + + lm_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + return_dict=True + ) + loss_gen = lm_outputs.loss + + # Encoder Decoder + else: + # inputs_embeds, attention_mask = self.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + + # now merge the multi-modal inputs + # if is_vid: + # inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + # attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + # else: + inputs_embeds = torch.cat([vis_embed, text_embeds], dim=1) + attention_mask = torch.cat([vis_mask, text_mask], dim=1) + + decoder_ids, decoder_mask = self.tokenize_text(ans, device, max_len=None) # pad the longest + + labels = decoder_ids.masked_fill(decoder_ids == self.tokenizer.pad_token_id, -100) + decoder_ids = self.shift_right(labels) + decoder_inputs_embeds = self.text_embedding(decoder_ids) + + lm_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_inputs_embeds=decoder_inputs_embeds, + decoder_attention_mask=decoder_mask, + labels=labels, + return_dict=True + ) + + loss_gen = lm_outputs.loss + + return dict( + loss_stc = loss_stc * self.config.loss_dict['stc'], + loss_stm = loss_stm * self.config.loss_dict['stm'], + loss_vhc = loss_vhc * self.config.loss_dict['vhc'], + loss_vhm = loss_vhm * self.config.loss_dict['vhm'], + loss_gen = loss_gen * self.config.loss_dict['gen'], + ) \ No newline at end of file diff --git a/processors/__init__.py b/processors/__init__.py new file mode 100755 index 0000000..8b13789 --- /dev/null +++ b/processors/__init__.py @@ -0,0 +1 @@ + diff --git a/processors/base_processor.py b/processors/base_processor.py new file mode 100755 index 0000000..39b33cd --- /dev/null +++ b/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/processors/blip_processors.py b/processors/blip_processors.py new file mode 100755 index 0000000..b6c3929 --- /dev/null +++ b/processors/blip_processors.py @@ -0,0 +1,214 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re +import torch +from processors.base_processor import BaseProcessor +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + + segment_mean = (0.485, 0.456, 0.406) + segment_std = (0.229, 0.224, 0.225) + + self.normalize = transforms.Normalize(segment_mean, segment_std) + + +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#|:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +class BlipDialogProcessor(BlipCaptionProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def pre_caption_rm_period(self, text): + text = re.sub( + r"([.!\"()*#|:;~])", + " ", + text.lower(), + ) + text = re.sub( + r"\s{2,}", + " ", + text, + ) + text = text.rstrip("\n") + text = text.strip(" ") + + # truncate caption + text_words = text.split(" ") + if len(text_words) > self.max_words: + text = " ".join(text_words[: self.max_words]) + return text + + def pre_caption(self, text): + text = re.sub( + r"([\"()*#|:;~])", + " ", + text.lower(), + ) + text = re.sub( + r"\s{2,}", + " ", + text, + ) + text = text.rstrip("\n") + text = text.strip(" ") + + # truncate caption + text_words = text.split(" ") + if len(text_words) > self.max_words: + text = " ".join(text_words[: self.max_words]) + return text + + def __call__(self, caption, remove_period=False): + if remove_period: + caption = self.prompt + self.pre_caption_rm_period(caption) + else: + caption = self.prompt + self.pre_caption(caption) + return caption + + +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + # self.transform = transforms.Compose( + # [ + # transforms.RandomResizedCrop( + # image_size, + # scale=(min_scale, max_scale), + # interpolation=InterpolationMode.BICUBIC, + # ), + # transforms.ToTensor(), + # self.normalize, + # ] + # ) + self.transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC, antialias=True + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + + + # ### segment anything + # ''' + # x = (x - self.pixel_mean) / self.pixel_std + + # # Pad + # h, w = x.shape[-2:] + # padh = self.image_encoder.img_size - h + # padw = self.image_encoder.img_size - w + # x = F.pad(x, (0, padw, 0, padh)) + # ''' + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) \ No newline at end of file diff --git a/processors/randaugment.py b/processors/randaugment.py new file mode 100755 index 0000000..7034a49 --- /dev/null +++ b/processors/randaugment.py @@ -0,0 +1,398 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import cv2 +import numpy as np + +import torch + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32( + [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] + ) * factor + np.float32([[0.114], [0.587], [0.299]]) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = ( + np.array([(el - mean) * factor + mean for el in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert ( + frames.shape[-1] == 3 + ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." + + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + + num_frames = frames.shape[0] + + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + + frames = torch.stack( + list(map(self._aug, frames, ops, apply_or_not)), dim=0 + ).float() + + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return torch.from_numpy(img) + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/tasks/pre_train.py b/tasks/pre_train.py new file mode 100644 index 0000000..c39daa7 --- /dev/null +++ b/tasks/pre_train.py @@ -0,0 +1,413 @@ +import os +import datetime +import wandb +import torch +import pandas as pd +from time import time + +import torch.distributed as dist +from torch.distributed import ReduceOp + +from torch.nn.utils.clip_grad import clip_grad_value_ +from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts +from datasets.utils import get_datasets_media +from datasets.dataloader import MetaLoader +from utils.dist import is_main_process, get_rank, get_world_size +from utils.logger import setup_wandb, log_dict_to_wandb +from .retrieval_utils import evaluation_wrapper +import glog as logger + + +def run_epoch( + model, + train_dataloaders, + optimizer, + epoch, + global_step, + webvid_step, + cc3m_step, + device, + scheduler, + scaler, + config +): + model.train() + media_types = list(train_dataloaders.keys()) + + log_freq = config['log_freq'] + # metric_logger = MetricLogger(delimiter=' ') + # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}')) + # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}")) + + loss_names = ['loss_' + k for k in config['loss_dict'].keys()] + # for l in loss_names: + # for m in media_types: + # metric_logger.add_meter( + # f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}") + # ) + + + # header = '{} | Epoch = {}'.format(config['stage'], epoch) + + model_without_ddp = model + if config['distributed']: + model_without_ddp = model.module + for k in train_dataloaders: + train_dataloaders[k].sampler.set_epoch(epoch) + + train_dataloader = MetaLoader(name2loader=train_dataloaders) + + log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n' + log_text_template += '[Losses] mlm (x{}) = {:.4f} | vcc (x{}) = {:.4f} | vcm (x{}) = {:.4f} | stc (x{}) = {:.4f} | stm (x{}) = {:.4f}\n' + log_text_template += '[Other] lr = {:.4f} | temp = {:.4f} | eta = {}\n' + + # iterator = metric_logger.log_every(train_dataloader, log_freq, header) + local_step = 0 + for media_type, (vis, caption, neg_vis) in train_dataloader: + start = time() + # loss_dict = {} + vis = vis.to(device) + neg_vis = neg_vis.to(device) + # idx = idx.to(device) + + with torch.cuda.amp.autocast(enabled=config.fp16): + loss_dict = model(vis, caption, neg_vis, media_type) + # loss_dict.update(losses) + loss = sum(loss_dict.values()) + loss_accum_grad = loss / config.accum_grad_every + + scaler.scale(loss_accum_grad).backward() + + # Perfrom gradient clipping: unscale --> clip + if config['clip_grad_value'] > 0: + # scaler.unscale_(optimizer) + clip_grad_value_(model.parameters(), config.clip_grad_value) + + if local_step % config.accum_grad_every == 0: + scaler.step(optimizer) + scaler.update() + # scheduler.step(epoch, global_step) + scheduler.step() + optimizer.zero_grad() + + time_iter = time() - start + eta = (len(train_dataloader) - local_step - 1) * time_iter + eta = str(datetime.timedelta(seconds=eta)) + # log + log_dict_webvid = {} + log_dict_cc3m = {} + log_dict_rest = {} + for loss_name in loss_names: + value = loss_dict[loss_name] + value = value if isinstance(value, float) else value.item() + # metric_logger.update(**{f"{media_type}/{loss_name}": value}) + if media_type == "cc3m": + log_dict_cc3m[f"train/{media_type}/{loss_name}"] = value + else: + log_dict_webvid[f"train/{media_type}/{loss_name}"] = value + + # metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # metric_logger.update(temperature=model_without_ddp.temp.item()) + log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"] + log_dict_rest['train/other/temperature'] = model_without_ddp.temp.item() + + if is_main_process() and global_step % log_freq == 0 and local_step % config.accum_grad_every == 0: + log_dict_rest['train/other/step'] = global_step + if media_type == 'cc3m': + log_dict_cc3m['train/cc3m/step'] = cc3m_step + + log_text = log_text_template.format( + epoch, config.epochs-1, local_step, len(train_dataloader) , media_type, + config.loss_dict['mlm'], log_dict_cc3m['train/cc3m/loss_mlm'], + config.loss_dict['vcc'], log_dict_cc3m['train/cc3m/loss_vcc'], + config.loss_dict['vcm'], log_dict_cc3m['train/cc3m/loss_vcm'], + config.loss_dict['stc'], log_dict_cc3m['train/cc3m/loss_stc'], + config.loss_dict['stm'], log_dict_cc3m['train/cc3m/loss_stc'], + log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta + ) + logger.info(log_text) + + if config['wandb_enabled']: + wandb.log(log_dict_rest) + wandb.log(log_dict_cc3m) + # log_text_template = '[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n' + # log_text_template += '[losses: mlm = {:.4f} | vcc = {:4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f}]\n' + # log_text_template += '[Other: lr = {:.4f} | temp = {:4f}]\n' + + else: + log_dict_webvid['train/webvid/step'] = webvid_step + log_text = log_text_template.format( + epoch, config.epochs-1, local_step, len(train_dataloader) , media_type, + config.loss_dict['mlm'], log_dict_webvid['train/webvid/loss_mlm'], + config.loss_dict['vcc'], log_dict_webvid['train/webvid/loss_vcc'], + config.loss_dict['vcm'], log_dict_webvid['train/webvid/loss_vcm'], + config.loss_dict['stc'], log_dict_webvid['train/webvid/loss_stc'], + config.loss_dict['stm'], log_dict_webvid['train/webvid/loss_stm'], + log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta + ) + logger.info(log_text) + + if config['wandb_enabled']: + wandb.log(log_dict_rest) + wandb.log(log_dict_webvid) + + + if media_type == "cc3m": + cc3m_step += 1 + else: + webvid_step += 1 + global_step += 1 + local_step += 1 + # gather the stats from all processes + # metric_logger.synchronize_between_processes() + # logger.info(f"Averaged stats: {metric_logger.global_avg()}") + + return global_step, webvid_step, cc3m_step + + +def eval(model, val_dataloader, device, epoch, config): + + model.eval() + + log_text_template = '\n' + '-' * 25 + '\n[Val Epoch{}][Iter. {}/{}][Media-type {}]\n' + log_text_template += '[Losses] mlm = {:.4f} | vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} \n' + + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + + cum_loss_stc = 0 + cum_loss_stm = 0 + cum_loss_vcc = 0 + cum_loss_vcm = 0 + cum_loss_mlm = 0 + cum_loss_tot = 0 + val_step = 0 + + # val_dataloader = MetaLoader(name2loader=val_dataloaders) + media_type = val_dataloader.dataset.medium + + if is_main_process(): + start_time = time() + + # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader: + for vis, caption, neg_vis in val_dataloader: + # for vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader: + vis = vis.to(device) + neg_vis = neg_vis.to(device) + # idx = idx.to(device) + + with torch.cuda.amp.autocast(enabled=config['fp16']): + with torch.no_grad(): + # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type) + # loss_dict = model(vis, caption, neg_vis, neg_caption, media_type, file, neg_file) + loss_dict = model(vis, caption, neg_vis, media_type) + + loss = sum(loss_dict.values()) + loss_stc = loss_dict['loss_stc'] + loss_stm = loss_dict['loss_stm'] + loss_vcc = loss_dict['loss_vcc'] + loss_vcm = loss_dict['loss_vcm'] + loss_mlm = loss_dict['loss_mlm'] + + if config['distributed']: + dist.all_reduce(loss, op=ReduceOp.AVG) + if config.loss_dict['stc'] != 0: + dist.all_reduce(loss_stc, op=ReduceOp.AVG) + if config.loss_dict['stm'] != 0: + dist.all_reduce(loss_stm, op=ReduceOp.AVG) + if config.loss_dict['vcc'] != 0: + dist.all_reduce(loss_vcc, op=ReduceOp.AVG) + if config.loss_dict['vcm'] != 0: + dist.all_reduce(loss_vcm, op=ReduceOp.AVG) + if config.loss_dict['mlm'] != 0: + dist.all_reduce(loss_mlm, op=ReduceOp.AVG) + + if is_main_process(): + cum_loss_tot += loss.item() + cum_loss_stc += loss_stc.item() + cum_loss_stm += loss_stm.item() + cum_loss_vcc += loss_vcc.item() + cum_loss_vcm += loss_vcm.item() + cum_loss_mlm += loss_mlm.item() + + if val_step % config.log_freq == 0: + log_text = log_text_template.format( + epoch, val_step, len(val_dataloader), media_type, + loss_mlm, loss_vcc, loss_vcm, loss_stc, loss_stm) + # log_text_template = '\n' + '-' * 25 + '\n[Val Eoch{}][Iter. {}/{}][Media-type {}]\n' + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + # log_text = log_text_template.format( + # epoch, val_step, len(val_dataloader), media_type, + # loss_vcc, loss_vcm, loss_stc, loss_stm, 0, + # loss_vhc, loss_vhm, loss_chc, loss_chm, loss_gen + # ) + + logger.info(log_text) + # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format( + # epoch, val_step, len(val_dataloader), gen_loss, loss + # )) + val_step += 1 + + if config['distributed']: + dist.barrier() + + if is_main_process(): + duration = time() - start_time + + cum_loss_tot /= len(val_dataloader) + cum_loss_stc /= len(val_dataloader) + cum_loss_stm /= len(val_dataloader) + cum_loss_vcc /= len(val_dataloader) + cum_loss_vcm /= len(val_dataloader) + cum_loss_mlm /= len(val_dataloader) + + # cum_loss_vhc /= len(val_dataloader) + # cum_loss_vhm /= len(val_dataloader) + # cum_loss_chc /= len(val_dataloader) + # cum_loss_chm /= len(val_dataloader) + # cum_loss_gen /= len(val_dataloader) + logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_total = {:.4f}'.format( + datetime.timedelta(seconds=int(duration)), cum_loss_tot + )) + + # logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format( + # datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot + # )) + + loss_dict = { + 'stc': cum_loss_stc, + 'stm': cum_loss_stm, + 'vcc': cum_loss_vcc, + 'vcm': cum_loss_vcm, + # 'vhc': cum_loss_vhc, + # 'vhm': cum_loss_vhm, + # 'chc': cum_loss_chc, + # 'chm': cum_loss_chm, + 'mlm': cum_loss_mlm, + # 'gen': cum_loss_gen, + 'tot': cum_loss_tot + } + return loss_dict + + +def pre_train( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + webvid_step, + cc3m_step, + scheduler, + scaler, + start_epoch, + config +): + if is_main_process() and config['wandb_enabled']: + run = setup_wandb(config) + setup_seed(config['seed'] + get_rank()) + device = torch.device('cuda:{}'.format(config['gpu'])) + + if is_main_process() and config['wandb_enabled']: + wandb.watch(model) + + best = float('inf') + best_epoch = 0 + + logger.info('[INFO] Start training...') + start_time_all = time() + for epoch in range(start_epoch, config['epochs']): + if not config['evaluate']: + start_time_epoch = time() + global_step, webvid_step, cc3m_step = run_epoch( + model, + train_dataloaders, + optimizer, + epoch, + global_step, + webvid_step, + cc3m_step, + device, + scheduler, + scaler, + config + ) + end_time_epoch = time() + epoch_time = end_time_epoch - start_time_epoch + epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time))) + logger.info(f'[INFO] Epoch took {epoch_time_str}') + + if not config['debugging']: + with torch.cuda.amp.autocast(enabled=config['fp16']): + # # TODO + # eval_res = {} + # for val_name, val_loader in val_dataloaders_dict.items(): + # res = evaluation_wrapper( + # model_without_ddp, val_loader, tokenizer, device, config, prefix=val_name + # ) + # eval_res.update(res) + val_res = {} + + for medium in val_dataloaders: + res = eval( + model, + val_dataloaders[medium], + device, + epoch, + config + ) + val_res[medium] = res + + if is_main_process(): + # Average across all datasets + avg_val_res = average_dicts(val_res) + # log to wandb + if config.wandb_enabled: + for medium in val_res: + log_dict_val = {} + # log_dict_val[f'val/{medium}/step'] = epoch + for l in val_res[medium]: + log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l] + wandb.log(log_dict_val) + # for p, v in eval_res.items(): + # log_dict_to_wandb(v, step=global_step, prefix=p) + if config.stop_key is not None and config.stop_key in avg_val_res: + cur_best = avg_val_res[config.stop_key] + else: # stop_key = None + cur_best = best - 1 # save the last as the best + + # Don't save vit weights as they are frozen + state_dict = model_without_ddp.state_dict() + state_dict = {k:v for k,v in state_dict.items() if 'visual_encoder' not in k} + + save_obj = { + "model": state_dict, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "scaler": scaler.state_dict(), + "config": config, + "epoch": epoch, + "global_step": global_step, + } + torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth")) + + if not config.evaluate and cur_best < best: + torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth")) + # eval_file = "eval_res_best.json" + # eval_res.to_json(os.path.join(config.log_dir, eval_file)) + best = cur_best + + if config.evaluate: + break + if config['distributed']: + dist.barrier() + + total_time = time() - start_time_all + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f'[INFO] Training took {total_time_str}') + + if is_main_process() and config['wandb_enabled']: + run.finish() + diff --git a/tasks/retrieval_utils.py b/tasks/retrieval_utils.py new file mode 100644 index 0000000..be9192b --- /dev/null +++ b/tasks/retrieval_utils.py @@ -0,0 +1,435 @@ +import datetime +import logging +import time + +import numpy as np +import torch +import torch.distributed as dist +from einops import rearrange + +from models.criteria import get_sim +from utils.basic import MetricLogger +from utils.dist import get_rank, get_world_size + +logger = logging.getLogger(__name__) + + +def extract_text_feats(texts, max_txt_l, tokenizer, model, device): + num_text = len(texts) + text_bs = 256 + text_feats = [] + text_atts = [] + + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_txt_l, + return_tensors="pt", + ).to(device) + + text_feat = model.encode_text(text_input)[0] + text_feats.append(text_feat) + text_atts.append(text_input.attention_mask) + + text_feats = torch.cat(text_feats, dim=0) + text_atts = torch.cat(text_atts, dim=0) + return text_feats, text_atts + + +def extract_vision_feats(data_loader, model, device, config): + image_feats_all = [] + pooled_image_feats_all = [] + metric_logger = MetricLogger(delimiter=" ") + header = "extracting image feats" + iterator = metric_logger.log_every(data_loader, 100, header) + media_type = data_loader.dataset.medium + for vis, _ in iterator: + vis = vis.to(device, non_blocking=True) + vis_feat, pooled_vis_feat = model.get_vis_enc_for_eval(vis, media_type) + # if config.evaluation.eval_frame_ensemble == "concat": # default + # image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous() + vis_feat = vis_feat.unsqueeze(1) # (bsz, 1, l, d) + # else: + # assert config.video_input.num_frames == 1, "only support single-frame" + # assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"] + if not config.eval_offload: + image_feats_all.append(vis_feat.cpu()) + pooled_image_feats_all.append(pooled_vis_feat.cpu()) + else: + image_feats_all.append(vis_feat) + pooled_image_feats_all.append(pooled_vis_feat) + + image_feats_all = torch.cat(image_feats_all, dim=0) + + pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0) + return image_feats_all, pooled_image_feats_all + + +@torch.no_grad() +def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""): + with torch.cuda.amp.autocast(enabled=config.fp16): + i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation( + model, data_loader, tokenizer, device, config + ) + score_pairs = [ + (prefix + "/", i2t_x, t2i_x), + (prefix + "_emb/", i2t_emb, t2i_emb), + ] + res = dict() + for name, i2t, t2i in score_pairs: + if i2t is not None: + txt2img_ids = data_loader.dataset.txt2vis + img2txt_ids = data_loader.dataset.vis2txt + res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids) + return res + + +@torch.no_grad() +def evaluation(model, data_loader, tokenizer, device, config): + model.eval() + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + dtype = torch.half if config.fp16 else torch.float + media_type = data_loader.dataset.medium + logger.info(f"Start evaluation for {media_type}") + + logger.info("Computing dual encoder features...") + start_time = time.time() + + # this computes all features in each GPU + texts = data_loader.dataset.text + max_txt_l = config.max_cap_len + + text_feats, text_atts = extract_text_feats( + texts, max_txt_l, tokenizer, model, device + ) # (bsz, Lt, d), (bsz, Lt) + + image_feats, pooled_image_feats = extract_vision_feats( + data_loader, model, device, config + ) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d) + logger.info("Finished feature extraction") + logger.info("Computing ITC scores [dot-product]") + _pooled_image_feats = ( + pooled_image_feats.to(device, non_blocking=True) + if config.eval_offload + else pooled_image_feats + ) + i2t_scores, t2i_scores = get_sim( + model.vis_proj(_pooled_image_feats), model.cap_proj(text_feats[:, 0]) + ) + logger.info("Computing ITC scores [dot-product], done!") + + num_images = len(data_loader.dataset.vis) + i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to( + device, torch.float, non_blocking=True + ) + + # computes only part of the scores at each GPU, gather at the end + logger.info("Rerank dual-encoder results with cross-encoder...") + num_tasks = get_world_size() + rank = get_rank() + # only uses the part associated with the raw eval set + # compute image2text # + step = num_images // num_tasks + 1 + start = rank * step + end = min(num_images, start + step) + + text_encoder = model.get_expert_encoder('vis_cap_grounding') + + iterator = metric_logger.log_every(i2t_scores[start:end], 100, header) + logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}") + + # generate score for each clip, and aggregate all clip scores for a video + n_clip_per_video = 1 + # ( + # image_feats.shape[1] if not False else image_feats[0].shape[1] + # ) + + # logger.info( + # f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={'concat'}" + # ) + for i, sims in enumerate(iterator): + k = min(len(sims), config.eval_k_test) + topk_sim, topk_idx = sims.topk(k=k, dim=0) + + clip_scores = [] + for clip_idx in range(n_clip_per_video): + # if config.deep_fusion: + # encoder_output = [ + # feat[start + i, clip_idx].to(device, non_blocking=True) + # for feat in image_feats + # ] + + # else: + encoder_output = ( + image_feats[start + i, clip_idx].to(device, non_blocking=True) + if config.eval_offload + else image_feats[start + i, clip_idx] + ) # (#frm*Li, d) + + """ original + encoder_output = encoder_output.repeat(k, 1, 1) # (k=128, #frm*Li, d) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long + ).to(device, non_blocking=True) + output = text_encoder( + encoder_embeds=text_feats[topk_idx], + attention_mask=text_atts[topk_idx], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + mode="fusion" + ) + + itm_embeds = output.last_hidden_state[:, 0] + """ + + # new + bs = 32 + # bs = config.batch_size_test.video + itm_embeds = [] + + # if config.deep_fusion: + # encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output] + # encoder_att = [ + # torch.ones(feat.size()[:-1], dtype=torch.long).to( + # device, non_blocking=True + # ) + # for feat in encoder_output + # ] + # else: + encoder_output = encoder_output.repeat(bs, 1, 1) # (k=128, #frm*Li, d) + encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to( + device, non_blocking=True + ) + + for j in range(0, len(topk_idx), bs): + output = text_encoder( + encoder_embeds=text_feats[topk_idx[j : j + bs]], + attention_mask=text_atts[topk_idx[j : j + bs]], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + batch_itm_embeds = output.last_hidden_state[:, 0] + itm_embeds.append(batch_itm_embeds) + itm_embeds = torch.cat(itm_embeds, dim=0) + # end new + + score = model.vcm_head(itm_embeds)[:, 1] + clip_scores.append(score) + + # if len(clip_scores) == 1: + score = clip_scores[0] + # else: + # assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"] + # clip_scores = torch.stack(clip_scores) # (#clips, k) + # if config.evaluation.eval_frame_ensemble == "mean": + # score = clip_scores.mean(0) + # elif config.evaluation.eval_frame_ensemble == "max": + # score = clip_scores.max(0)[0] + # elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp + # score = torch.logsumexp(clip_scores, dim=0) + # else: + # raise ValueError( + # "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1." + # ) + + i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype) + + # compute text2image # + num_text = len(data_loader.dataset.text) + t2i_scores_x = torch.full((num_text, len(data_loader.dataset.vis)), -100.0).to( + device, torch.float, non_blocking=True + ) + + step = num_text // num_tasks + 1 + start = rank * step + end = min(num_text, start + step) + + iterator = metric_logger.log_every(t2i_scores[start:end], 100, header) + logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}") + # generate score for each clip, and aggregate all clip scores for a video + n_clip_per_video = 1 + # ( + # image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1] + # ) + for i, sims in enumerate(iterator): + k = min(len(sims), config.eval_k_test) + topk_sim, topk_idx = sims.topk(k=k, dim=0) + # topk_idx = + clip_scores = [] + for clip_idx in range(n_clip_per_video): + + """old + encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \ + if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long + ).to(device, non_blocking=True) + output = text_encoder( + encoder_embeds=text_feats[start+i].repeat(k, 1, 1), + attention_mask=text_atts[start+i].repeat(k, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + mode="fusion" + ) + + itm_embeds = output.last_hidden_state[:, 0] + """ + + # new + bs = 32 + # bs = config.batch_size_test.video + itm_embeds = [] + for j in range(0, len(topk_idx), bs): + + # if config.deep_fusion: + # encoder_output = [ + # feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True) + # for feat in image_feats + # ] + # encoder_att = [ + # torch.ones(feat.size()[:-1], dtype=torch.long).to( + # device, non_blocking=True + # ) + # for feat in encoder_output + # ] + # else: + encoder_output = ( + image_feats[topk_idx[j : j + bs], clip_idx].to( + device, non_blocking=True + ) + if config.eval_offload + else image_feats[topk_idx[j : j + bs], clip_idx] + ) + encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to( + device, non_blocking=True + ) + + repeat_n = ( + encoder_output.shape[0] + # if not config.deep_fusion + # else encoder_output[0].shape[0] + ) + output = text_encoder( + encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1), + attention_mask=text_atts[start + i].repeat(repeat_n, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + # mode="fusion", + ) + + batch_itm_embeds = output.last_hidden_state[:, 0] + itm_embeds.append(batch_itm_embeds) + + itm_embeds = torch.cat(itm_embeds, dim=0) + # end new + + score = model.vcm_head(itm_embeds)[:, 1] + clip_scores.append(score) + + # if len(clip_scores) == 1: + score = clip_scores[0] + # else: + # assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"] + # clip_scores = torch.stack(clip_scores) # (#clips, k) + # if config.evaluation.eval_frame_ensemble == "mean": + # score = clip_scores.mean(0) + # elif config.evaluation.eval_frame_ensemble == "max": + # score = clip_scores.max(0)[0] + # elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp + # score = torch.logsumexp(clip_scores, dim=0) + # else: + # raise ValueError( + # "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1." + # ) + + t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype) + + if config.distributed: + # gether across GPUs + dist.barrier() + dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) + dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f"Evaluation time {total_time_str}") + + return ( + i2t_scores_x.cpu().numpy(), + t2i_scores_x.cpu().numpy(), + i2t_scores.cpu().numpy(), + i2t_scores.T.cpu().numpy(), + ) + + +@torch.no_grad() +def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): + # Images->Text + ranks = np.zeros(scores_i2t.shape[0]) + for index, score in enumerate(scores_i2t): + inds = np.argsort(score)[::-1] + # Score + gt_txt_ids = img2txt[index] + if isinstance(gt_txt_ids, int): + ranks[index] = np.where(inds == gt_txt_ids)[0][0] + else: + rank = 1e20 + for i in gt_txt_ids: + tmp = np.where(inds == i)[0][0] + if tmp < rank: + rank = tmp + ranks[index] = rank + + # Compute metrics + tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + + # Text->Images + ranks = np.zeros(scores_t2i.shape[0]) + + for index, score in enumerate(scores_t2i): + inds = np.argsort(score)[::-1] + gt_img_ids = txt2img[index] + if isinstance(gt_img_ids, int): + ranks[index] = np.where(inds == gt_img_ids)[0][0] + else: # list, used in the case each caption has multiple GT images + # Score + rank = 1e20 + for i in gt_img_ids: + tmp = np.where(inds == i)[0][0] + if tmp < rank: + rank = tmp + ranks[index] = rank + + # Compute metrics + ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + + tr_mean = (tr1 + tr5 + tr10) / 3 + ir_mean = (ir1 + ir5 + ir10) / 3 + r_mean = (tr_mean + ir_mean) / 2 + + eval_result = { + "txt_r1": tr1, + "txt_r5": tr5, + "txt_r10": tr10, + "txt_r_mean": tr_mean, + "vis_r1": ir1, + "vis_r5": ir5, + "vis_r10": ir10, + "vis_r_mean": ir_mean, + "r_mean": r_mean, + } + eval_result = {k: round(v, 2) for k, v in eval_result.items()} + return eval_result diff --git a/tasks/stage_2.py b/tasks/stage_2.py new file mode 100644 index 0000000..e083d16 --- /dev/null +++ b/tasks/stage_2.py @@ -0,0 +1,373 @@ +import os +import datetime +import wandb +import torch +import pandas as pd +from time import time + +import torch.distributed as dist +from torch.distributed import ReduceOp + +from torch.nn.utils.clip_grad import clip_grad_value_ +from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts +from datasets.utils import get_datasets_media +from datasets.dataloader import MetaLoader +from utils.dist import is_main_process, get_rank, get_world_size +from utils.logger import setup_wandb, log_dict_to_wandb +from .retrieval_utils import evaluation_wrapper +import glog as logger + + +def run_epoch( + model, + train_dataloaders, + optimizer, + epoch, + global_step, + device, + scheduler, + scaler, + config +): + model.train() + media_types = list(train_dataloaders.keys()) + + log_freq = config['log_freq'] + # metric_logger = MetricLogger(delimiter=' ') + # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}')) + # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}")) + + loss_names = ['loss_' + k for k in config['loss_dict'].keys()] + # for l in loss_names: + # for m in media_types: + # metric_logger.add_meter( + # f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}") + # ) + + + # header = '{} | Epoch = {}'.format(config['stage'], epoch) + + model_without_ddp = model + if config['distributed']: + model_without_ddp = model.module + for k in train_dataloaders: + train_dataloaders[k].sampler.set_epoch(epoch) + + train_dataloader = MetaLoader(name2loader=train_dataloaders) + + log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n' + log_text_template += '[Losses] gen = {:.4f} | vhc = {:.4f} | vhm = {:.4f} | stc = {:.4f} | stm = {:.4f}\n' + log_text_template += '[Other] lr = {:.4f} | temp = {:.4f} | iter_time = {:.2f} | eta = {}\n' + + # iterator = metric_logger.log_every(train_dataloader, log_freq, header) + local_step = 0 + for media_type, (vis, caption, history, answer) in train_dataloader: + # for media_type, (vis, caption, neg_vis, neg_caption, idx) in train_dataloader: + + start = time() + # loss_dict = {} + vis = vis.to(device) + # neg_vis = neg_vis.to(device) + # idx = idx.to(device) + + with torch.cuda.amp.autocast(enabled=config.fp16): + loss_dict = model(vis, caption, history, answer, media_type) + loss = sum(loss_dict.values()) + loss_accum_grad = loss / config.accum_grad_every + + scaler.scale(loss_accum_grad).backward() + + # Perfrom gradient clipping: unscale --> clip + if config['clip_grad_value'] > 0: + scaler.unscale_(optimizer) + clip_grad_value_(model.parameters(), config.clip_grad_value) + + if local_step % config.accum_grad_every == 0: + scaler.step(optimizer) + scaler.update() + scheduler.step() + optimizer.zero_grad() + + time_iter = time() - start + eta = (len(train_dataloader) - local_step - 1) * time_iter + eta = str(datetime.timedelta(seconds=eta)) + # log + log_dict = {} + log_dict_rest = {} + for loss_name in loss_names: + value = loss_dict[loss_name] + value = value if isinstance(value, float) else value.item() + log_dict[f"train/{media_type}/{loss_name}"] = value + + # metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # metric_logger.update(temperature=model_without_ddp.temp.item()) + log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"] + log_dict_rest['train/other/temperature'] = model_without_ddp.temp.item() + + if is_main_process() and global_step % log_freq == 0 and local_step % config.accum_grad_every == 0: + # log_dict['train/webvid/step'] = webvid_step + log_text = log_text_template.format( + epoch, config.epochs-1, local_step, len(train_dataloader) , media_type, + log_dict['train/champagne/loss_gen'], log_dict['train/champagne/loss_vhc'], log_dict['train/champagne/loss_vhm'], + log_dict['train/champagne/loss_stc'], log_dict['train/champagne/loss_stm'], + log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], time_iter, eta + ) + logger.info(log_text) + log_dict_rest['train/other/step'] = global_step + log_dict['train/champagne/step'] = global_step + + if config['wandb_enabled']: + wandb.log(log_dict) + wandb.log(log_dict_rest) + + global_step += 1 + local_step += 1 + # gather the stats from all processes + # metric_logger.synchronize_between_processes() + # logger.info(f"Averaged stats: {metric_logger.global_avg()}") + + return global_step + + +def eval(model, val_dataloader, device, epoch, config): + + model.eval() + + log_text_template = '\n' + '-' * 25 + '\n[Val Epoch{}][Iter. {}/{}][Media-type {}]\n' + log_text_template += '[Losses] gen = {:.4f} | vhc = {:.4f} | vhm = {:.4f} | stc = {:.4f} | stm = {:.4f} \n' + + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + + cum_loss_stc = 0 + cum_loss_stm = 0 + cum_loss_vhc = 0 + cum_loss_vhm = 0 + cum_loss_gen = 0 + cum_loss_tot = 0 + val_step = 0 + + # val_dataloader = MetaLoader(name2loader=val_dataloaders) + media_type = val_dataloader.dataset.medium + + if is_main_process(): + start_time = time() + + # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader: + for vis, caption, history, answer in val_dataloader: + # for vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader: + vis = vis.to(device) + # neg_vis = neg_vis.to(device) + # idx = idx.to(device) + + with torch.cuda.amp.autocast(enabled=config['fp16']): + with torch.no_grad(): + # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type) + loss_dict = model(vis, caption, history, answer, media_type) + + loss = sum(loss_dict.values()) + loss_stc = loss_dict['loss_stc'] + loss_stm = loss_dict['loss_stm'] + loss_vhc = loss_dict['loss_vhc'] + loss_vhm = loss_dict['loss_vhm'] + loss_gen = loss_dict['loss_gen'] + + if config['distributed']: + dist.all_reduce(loss, op=ReduceOp.AVG) + if config.loss_dict['stc'] != 0: + dist.all_reduce(loss_stc, op=ReduceOp.AVG) + if config.loss_dict['stm'] != 0: + dist.all_reduce(loss_stm, op=ReduceOp.AVG) + if config.loss_dict['vhc'] != 0: + dist.all_reduce(loss_vhc, op=ReduceOp.AVG) + if config.loss_dict['vhm'] != 0: + dist.all_reduce(loss_vhm, op=ReduceOp.AVG) + if config.loss_dict['gen'] != 0: + dist.all_reduce(loss_gen, op=ReduceOp.AVG) + + if is_main_process(): + cum_loss_tot += loss.item() + cum_loss_stc += loss_stc.item() + cum_loss_stm += loss_stm.item() + cum_loss_vhc += loss_vhc.item() + cum_loss_vhm += loss_vhm.item() + cum_loss_gen += loss_gen.item() + + if val_step % config.log_freq == 0: + log_text = log_text_template.format( + epoch, val_step, len(val_dataloader), media_type, + loss_gen, loss_vhc, loss_vhm, loss_stc, loss_stm) + # log_text_template = '\n' + '-' * 25 + '\n[Val Eoch{}][Iter. {}/{}][Media-type {}]\n' + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + # log_text = log_text_template.format( + # epoch, val_step, len(val_dataloader), media_type, + # loss_vcc, loss_vcm, loss_stc, loss_stm, 0, + # loss_vhc, loss_vhm, loss_chc, loss_chm, loss_gen + # ) + + logger.info(log_text) + # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format( + # epoch, val_step, len(val_dataloader), gen_loss, loss + # )) + val_step += 1 + + if config['distributed']: + dist.barrier() + + if is_main_process(): + duration = time() - start_time + + cum_loss_tot /= len(val_dataloader) + cum_loss_stc /= len(val_dataloader) + cum_loss_stm /= len(val_dataloader) + cum_loss_vhc /= len(val_dataloader) + cum_loss_vhm /= len(val_dataloader) + cum_loss_gen /= len(val_dataloader) + + # cum_loss_vhc /= len(val_dataloader) + # cum_loss_vhm /= len(val_dataloader) + # cum_loss_chc /= len(val_dataloader) + # cum_loss_chm /= len(val_dataloader) + # cum_loss_gen /= len(val_dataloader) + logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_total = {:.4f}'.format( + datetime.timedelta(seconds=int(duration)), cum_loss_tot + )) + + # logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format( + # datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot + # )) + + # switch back to training mode + model.train() + + loss_dict = { + 'stc': cum_loss_stc, + 'stm': cum_loss_stm, + 'vhc': cum_loss_vhc, + 'vhm': cum_loss_vhm, + # 'vhc': cum_loss_vhc, + # 'vhm': cum_loss_vhm, + # 'chc': cum_loss_chc, + # 'chm': cum_loss_chm, + 'gen': cum_loss_gen, + # 'gen': cum_loss_gen, + 'tot': cum_loss_tot + } + return loss_dict + + +def train( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + scheduler, + scaler, + start_epoch, + config +): + if is_main_process() and config['wandb_enabled']: + run = setup_wandb(config) + setup_seed(config['seed'] + get_rank()) + device = torch.device('cuda:{}'.format(config['gpu'])) + + if is_main_process() and config['wandb_enabled']: + wandb.watch(model) + + best = float('inf') + best_epoch = 0 + + logger.info('[INFO] Start training...') + start_time_all = time() + for epoch in range(start_epoch, config['epochs']): + if not config['evaluate']: + start_time_epoch = time() + global_step = run_epoch( + model, + train_dataloaders, + optimizer, + epoch, + global_step, + device, + scheduler, + scaler, + config + ) + end_time_epoch = time() + epoch_time = end_time_epoch - start_time_epoch + epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time))) + logger.info(f'[INFO] Epoch took {epoch_time_str}') + + if not config['debugging']: + with torch.cuda.amp.autocast(enabled=config['fp16']): + val_res = {} + + for medium in val_dataloaders: + res = eval( + model, + val_dataloaders[medium], + device, + epoch, + config + ) + val_res[medium] = res + + + if is_main_process(): + # Average across all datasets + avg_val_res = average_dicts(val_res) + # log to wandb + if config.wandb_enabled: + for medium in val_res: + log_dict_val = {} + # log_dict_val[f'val/{medium}/step'] = epoch + for l in val_res[medium]: + log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l] + wandb.log(log_dict_val) + # for p, v in eval_res.items(): + # log_dict_to_wandb(v, step=global_step, prefix=p) + if config.stop_key is not None and config.stop_key in avg_val_res: + cur_best = avg_val_res[config.stop_key] + else: # stop_key = None + cur_best = best - 1 # save the last as the best + + # Don't save vit and llm weights as they are frozen + state_dict = model_without_ddp.state_dict() + if config.freeze_vit: + state_dict = {k:v for k,v in state_dict.items() if 'visual_encoder' not in k} + + if config.freeze_llm: + state_dict = {k:v for k,v in state_dict.items() if 'llm' not in k} + + save_obj = { + "model": state_dict, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "scaler": scaler.state_dict(), + "config": config, + "epoch": epoch, + "global_step": global_step, + } + torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth")) + + if not config.evaluate and cur_best < best: + torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth")) + # eval_file = "eval_res_best.json" + # eval_res.to_json(os.path.join(config.log_dir, eval_file)) + best = cur_best + + if config.evaluate: + break + if config['distributed']: + dist.barrier() + + total_time = time() - start_time_all + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f'[INFO] Training took {total_time_str}') + + if is_main_process() and config['wandb_enabled']: + run.finish() + diff --git a/tasks/stage_3.py b/tasks/stage_3.py new file mode 100644 index 0000000..10b6ee9 --- /dev/null +++ b/tasks/stage_3.py @@ -0,0 +1,1051 @@ +import os +import datetime +import wandb +import torch +import json +import numpy as np +from copy import deepcopy +from time import time +import torch.nn.functional as F + +import torch.distributed as dist +from torch.distributed import ReduceOp +from torch.nn.utils.clip_grad import clip_grad_value_ +from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts +from datasets.utils import get_datasets_media +from datasets.dataloader import MetaLoader +from utils.dist import is_main_process, get_rank, get_world_size +from utils.logger import setup_wandb, log_dict_to_wandb +from .retrieval_utils import evaluation_wrapper +import glog as logger + + +def run_epoch( + model, + train_dataloaders, + # expert_tokenizer, + # enc_dec_tokenizer, + optimizer, + epoch, + global_step, + visdial_step, + avsd_step, + nextqa_step, + device, + scheduler, + scaler, + config +): + model.train() + media_types = list(train_dataloaders.keys()) + + log_freq = config['log_freq'] + # metric_logger = MetricLogger(delimiter=' ') + # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}')) + # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}")) + + loss_names = ['loss_' + k for k in config['loss_dict'].keys()] + # for l in loss_names: + # for m in media_types: + # metric_logger.add_meter( + # f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}") + # ) + + # header = '{} | Epoch = {}'.format(config['stage'], epoch) + + model_without_ddp = model + if config['distributed']: + model_without_ddp = model.module + for k in train_dataloaders: + train_dataloaders[k].sampler.set_epoch(epoch) + + # if len(train_dataloaders) == 1: + # train_dataloader = list(train_dataloaders.values())[0] + # else: + train_dataloader = MetaLoader(name2loader=train_dataloaders) + + log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n' + log_text_template += '[Loss] tot = {:.4f} | gen = {:.4f} \n' + log_text_template += '[Other] lr = {:.6f} | iter_time = {:.2f} | eta = {}\n' + + # iterator = metric_logger.log_every(train_dataloader, log_freq, header) + local_step = 0 + # vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list + for media_type, batch in train_dataloader: + vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3] + + start = time() + vis = vis.to(device) + + with torch.cuda.amp.autocast(enabled=config.fp16): + loss_dict = model(vis, caption, history, answer, media_type) + loss = sum(loss_dict.values()) + loss = loss / config['accum_grad_every'] + + scaler.scale(loss).backward() + + # Perfrom gradient clipping: unscale --> clip + if config['clip_grad_value'] > 0: + # scaler.unscale_(optimizer) + clip_grad_value_(model.parameters(), config['clip_grad_value']) + + if local_step % config.accum_grad_every == 0: + scaler.step(optimizer) + scaler.update() + scheduler.step() + optimizer.zero_grad() + + time_iter = time() - start + eta = (len(train_dataloader) - local_step - 1) * time_iter + eta = str(datetime.timedelta(seconds=eta)) + # log + log_dict_visdial = {} + log_dict_avsd = {} + log_dict_nextqa = {} + + log_dict_rest = {} + + for loss_name in loss_names: + value = loss_dict[loss_name] + value = value if isinstance(value, float) else value.item() + # metric_logger.update(**{f"{media_type}/{loss_name}": value}) + if media_type == 'visdial': + log_dict_visdial[f"train/{media_type}/{loss_name}"] = value + elif media_type == 'avsd': + log_dict_avsd[f"train/{media_type}/{loss_name}"] = value + elif media_type == 'nextqa': + log_dict_nextqa[f"train/{media_type}/{loss_name}"] = value + + log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"] + + if is_main_process() and local_step % log_freq == 0 and local_step % config['accum_grad_every'] == 0: + log_dict_rest['train/other/step'] = global_step + + if media_type == 'visdial': + log_dict_visdial['train/visdial/step'] = visdial_step + log_dict = log_dict_visdial + + elif media_type == 'avsd': + log_dict_avsd['train/avsd/step'] = avsd_step + log_dict = log_dict_avsd + + elif media_type == 'nextqa': + log_dict_nextqa['train/nextqa/step'] = nextqa_step + log_dict = log_dict_nextqa + + log_text = log_text_template.format( + epoch, config.epochs-1, local_step, len(train_dataloader) , media_type, loss.item(), + log_dict[f'train/{media_type}/loss_gen'], log_dict_rest['train/other/lr'], + time_iter, eta + ) + logger.info(log_text) + + if config['wandb_enabled']: + wandb.log(log_dict_rest) + wandb.log(log_dict) + + if media_type == 'visdial': + visdial_step += 1 + elif media_type == 'avsd': + avsd_step += 1 + elif media_type == 'nextqa': + nextqa_step += 1 + + + local_step += 1 + global_step += 1 + + return global_step, visdial_step, avsd_step, nextqa_step + + + # if is_main_process() and local_step % config['log_model_outputs_every'] == 0 and config['log_model_outputs']: + # predictions = [] + # labels = [] + # probs = F.softmax(logits, dim=-1) + # preds = torch.topk(probs, 1)[1].squeeze(-1) + # preds = preds.tolist() + # lm_labels_list = label_ids['input_ids'].tolist() + # lm_labels_list = [[s for s in label if s != 1] for label in lm_labels_list] + # # reponses = '' + # # labels = '' + # model_pred_text = '' + # for pred, label in zip(preds, lm_labels_list): + # predictions.append('\n' + 'Pred: ' + tokenizer_enc_dec.decode(pred) + '\n') + # labels.append('\n' + 'GT: ' + tokenizer_enc_dec.decode(label) + '\n') + + # if len(predictions) < 4: + # predictions = predictions[:4] + # labels = labels[:4] + + + # for label, pred in zip(labels, predictions): + # model_pred_text += label + pred + # model_pred_text += "---------------------" + # logger.info(model_pred_text) + + # # output['reponses'] = reponses + # # output['gt'] = labels + + + + +def eval(model, val_dataloader, device, epoch, config): + + model.eval() + + log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n' + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + + log_text_template += '[Losses] gen = {:.4f} \n' + + # cum_loss_stc = 0 + # cum_loss_stm = 0 + # cum_loss_vcc = 0 + # cum_loss_vcm = 0 + # cum_loss_vhc = 0 + # cum_loss_vhm = 0 + # cum_loss_chc = 0 + # cum_loss_chm = 0 + # cum_loss_mlm = 0 + cum_loss_gen = 0 + cum_loss_tot = 0 + val_step = 0 + media_type = val_dataloader.dataset.medium + if is_main_process(): + start_time = time() + + # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader: + for batch in val_dataloader: + + vis, caption, history, answer = batch[0], batch[1], batch[2], batch[3] + + vis = vis.to(device) + + with torch.cuda.amp.autocast(enabled=config.fp16): + with torch.no_grad(): + # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type) + # loss_dict, _ = model(vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, media_type) + loss_dict = model(vis, caption, history, answer, media_type) + + # loss_dict = model(vis, cap_ids, hist_ids, ques_ids, label_ids, media_type) + loss = sum(loss_dict.values()) + # loss_stc = loss_dict['loss_stc'] + # loss_stm = loss_dict['loss_stm'] + # loss_vcc = loss_dict['loss_vcc'] + # loss_vcm = loss_dict['loss_vcm'] + # loss_vhc = loss_dict['loss_vhc'] + # loss_vhm = loss_dict['loss_vhm'] + # loss_chc = loss_dict['loss_chc'] + # loss_chm = loss_dict['loss_chm'] + # loss_mlm = loss_dict['loss_mlm'] + loss_gen = loss_dict['loss_gen'] + + if config['distributed']: + dist.all_reduce(loss, op=ReduceOp.AVG) + # if config.loss_dict['stc'] != 0: + # dist.all_reduce(loss_stc, op=ReduceOp.AVG) + # if config.loss_dict['stm'] != 0: + # dist.all_reduce(loss_stm, op=ReduceOp.AVG) + # if config.loss_dict['vcc'] != 0: + # dist.all_reduce(loss_vcc, op=ReduceOp.AVG) + # if config.loss_dict['vcm'] != 0: + # dist.all_reduce(loss_vcm, op=ReduceOp.AVG) + # if config.loss_dict['vhc'] != 0: + # dist.all_reduce(loss_vhc, op=ReduceOp.AVG) + # if config.loss_dict['vhm'] != 0: + # dist.all_reduce(loss_vhm, op=ReduceOp.AVG) + # if config.loss_dict['chc'] != 0: + # dist.all_reduce(loss_chc, op=ReduceOp.AVG) + # if config.loss_dict['chm'] != 0: + # dist.all_reduce(loss_chm, op=ReduceOp.AVG) + # if config.loss_dict['mlm'] != 0: + # dist.all_reduce(loss_mlm, op=ReduceOp.AVG) + if config.loss_dict['gen'] != 0: + dist.all_reduce(loss_gen, op=ReduceOp.AVG) + + if is_main_process(): + cum_loss_tot += loss.item() + # cum_loss_stc += loss_stc.item() + # cum_loss_stm += loss_stm.item() + # cum_loss_vcc += loss_vcc.item() + # cum_loss_vcm += loss_vcm.item() + # cum_loss_vhc += loss_vhc.item() + # cum_loss_vhm += loss_vhm.item() + # cum_loss_chc += loss_chc.item() + # cum_loss_chm += loss_chm.item() + # cum_loss_mlm += loss_mlm.item() + cum_loss_gen += loss_gen.item() + + if val_step % config.log_freq == 0: + # log_text_template = '\n' + '-' * 25 + '\n[Val Eoch{}][Iter. {}/{}][Media-type {}]\n' + # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n' + # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n' + log_text = log_text_template.format( + epoch, val_step, len(val_dataloader), media_type, + # loss_vcc, loss_vcm, loss_stc, loss_stm, 0, + # loss_vhc, loss_vhm, loss_chc, loss_chm, + loss_gen + ) + + logger.info(log_text) + # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format( + # epoch, val_step, len(val_dataloader), gen_loss, loss + # )) + val_step += 1 + + if config['distributed']: + dist.barrier() + + if is_main_process(): + duration = time() - start_time + + cum_loss_tot /= len(val_dataloader) + # cum_loss_stc /= len(val_dataloader) + # cum_loss_stm /= len(val_dataloader) + # cum_loss_vcc /= len(val_dataloader) + # cum_loss_vcm /= len(val_dataloader) + # cum_loss_vhc /= len(val_dataloader) + # cum_loss_vhm /= len(val_dataloader) + # cum_loss_chc /= len(val_dataloader) + # cum_loss_chm /= len(val_dataloader) + # cum_loss_mlm /= len(val_dataloader) + cum_loss_gen /= len(val_dataloader) + + logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format( + datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot + )) + loss_dict = { + # 'stc': cum_loss_stc, + # 'stm': cum_loss_stm, + # 'vcc': cum_loss_vcc, + # 'vcm': cum_loss_vcm, + # 'vhc': cum_loss_vhc, + # 'vhm': cum_loss_vhm, + # 'chc': cum_loss_chc, + # 'chm': cum_loss_chm, + # 'mlm': cum_loss_mlm, + 'gen': cum_loss_gen, + 'tot': cum_loss_tot + } + return loss_dict + + +def ft_avsd( + model, + model_without_ddp, + train_dataloaders, + val_dataloaders, + optimizer, + global_step, + visdial_step, + avsd_step, + nextqa_step, + scheduler, + scaler, + start_epoch, + config +): + if is_main_process() and config['wandb_enabled']: + run = setup_wandb(config) + setup_seed(config['seed'] + get_rank()) + # device = torch.device('cuda:{}'.format(config['gpu'])) + device = config.device + # expert_tokenizer = model_without_ddp.expert_tokenizer + # enc_dec_tokenizer = model_without_ddp.enc_dec_tokenizer + + if is_main_process() and config['wandb_enabled']: + wandb.watch(model) + + best = float('inf') + + logger.info('[INFO] Start training...') + start_time_all = time() + for epoch in range(start_epoch, config['epochs']): + if not config['evaluate']: + if is_main_process(): + start_time_epoch = time() + + global_step, visdial_step, avsd_step, nextqa_step = run_epoch( + model, + train_dataloaders, + # expert_tokenizer, + # enc_dec_tokenizer, + optimizer, + epoch, + global_step, + visdial_step, + avsd_step, + nextqa_step, + device, + scheduler, + scaler, + config + ) + if is_main_process(): + end_time_epoch = time() + epoch_time = end_time_epoch - start_time_epoch + epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time))) + logger.info(f'[INFO] Epoch took {epoch_time_str}') + + if not config['debugging']: + with torch.cuda.amp.autocast(enabled=config['fp16']): + val_res = {} + + for medium in val_dataloaders: + res = eval( + model, + val_dataloaders[medium], + # expert_tokenizer, + # enc_dec_tokenizer, + device, + epoch, + config + ) + val_res[medium] = res + + if is_main_process(): + # Average across all datasets + avg_val_res = average_dicts(val_res) + # log to wandb + if config.wandb_enabled: + for medium in val_res: + log_dict_val = {} + # log_dict_val[f'val/{medium}/step'] = epoch + for l in val_res[medium]: + log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l] + wandb.log(log_dict_val) + # for p, v in eval_res.items(): + # log_dict_to_wandb(v, step=global_step, prefix=p) + if config.stop_key is not None and config.stop_key in avg_val_res: + cur_best = avg_val_res[config.stop_key] + else: # stop_key = None + cur_best = best - 1 # save the last as the best + + # Don't save vit and llm weights as they are frozen + state_dict = model_without_ddp.state_dict() + if config.freeze_vit: + state_dict = {k:v for k,v in state_dict.items() if 'visual_encoder' not in k} + + if config.freeze_llm: + state_dict = {k:v for k,v in state_dict.items() if 'llm' not in k} + + save_obj = { + "model": state_dict, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "scaler": scaler.state_dict(), + "config": config, + "epoch": epoch, + "global_step": global_step, + "visdial_step": visdial_step, + "avsd_step": avsd_step, + "nextqa_step": nextqa_step + } + torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth")) + + if not config.evaluate and cur_best < best: + torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth")) + # eval_file = "eval_res_best.json" + # eval_res.to_json(os.path.join(config.log_dir, eval_file)) + best = cur_best + + if config.evaluate: + break + if config['distributed']: + dist.barrier() + + total_time = time() - start_time_all + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f'[INFO] Training took {total_time_str}') + + if is_main_process() and config['wandb_enabled']: + run.finish() + + +def generate(model, dataloader, tag, config, gen_subset_num=None): + + model.eval() + responses = {} + # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec + device = next(model.parameters()).device # Assumes all model parameters are on the same device + # Generate the repsonse for each round + logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader))) + with torch.no_grad(): + # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader): + for counter, (vis, cap, hist, ans, vis_ids) in enumerate(dataloader): + + start_time = time() + vis = vis.to(device, non_blocking=True) + is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa'] + + # First get the visual features depending on the media type + with torch.cuda.amp.autocast(enabled=config.fp16): + cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None) + hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None) + + if config.use_moes: + if config.use_sep_spatial_temp_experts: + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid) + else: + vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + # construct the global input tensor --> use place holder for vis features + + if config.use_sep_spatial_temp_experts: + moe_outputs = model.moe_forward( + vis_embed_spatial, vis_spatial_mask, + vis_embed_temporal, vis_temporal_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds']) + temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + + else: + moe_outputs = model.moe_forward_no_sep_spatial_temporal( + vis_embed, vis_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds']) + + cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds']) + hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds']) + else: + cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids)) + hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids)) + vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + + if config.llm_family in ['llama', 'mistral']: + bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id + bos_embeds = model.text_embedding(bos) + bos_mask = cap_mask[:, :1] + + inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + if is_vid: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1) + + else: + inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + if config.use_moes: + if not config.drop_vis_features: + if config.use_sep_spatial_temp_experts: + if is_vid: + inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + + decoded_ids = model.llm.generate( + inputs_embeds=inputs_embeds, + do_sample=False, + top_p=config.top_p, + temperature=config.temperature, + num_beams=config.beam_depth, + length_penalty=config.length_penalty, + max_length=config.max_generation_length, + pad_token_id=model.tokenizer.pad_token_id, + eos_token_id=model.tokenizer.eos_token_id, + # use_cache=True + ) + + response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids] + + for vis_id, response in zip(vis_ids, response_batch): + responses[vis_id] = response + + time_elapsed = int(time() - start_time) + print('Generating resonse {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed)) + + # Create a file with all responses + with open(config['anno_avsd_test_dstc_{}'.format(config['dstc'])], 'r') as f: + test_data = json.load(f) + test_dialogs = deepcopy(test_data['dialogs']) + # Filter the predicted dialogs + test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs)) + + for i, dialog in enumerate(test_dialogs): + vid_id = dialog['image_id'] + gen_response = responses[vid_id] + round_num_to_answer = len(dialog['dialog'])-1 + assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__' + dialog['dialog'][round_num_to_answer]['answer'] = gen_response + test_dialogs[i] = dialog + + # Log the file + file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty']) + if gen_subset_num is not None: + file_name += f'-part_{gen_subset_num}' + file_name = f'{tag}_' + file_name + output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json') + with open(output_path, 'w') as f: + json.dump({'dialogs': test_dialogs}, f, indent=4) + logger.info('Results logged to {}'.format(output_path)) + # Switch back to training mode + model.train() + + +def generate_visdial(model, dataloader, tag, config, gen_subset_num=None): + + model.eval() + responses = {} + # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec + device = next(model.parameters()).device # Assumes all model parameters are on the same device + # Generate the repsonse for each round + logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader))) + with torch.no_grad(): + # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader): + for counter, (vis, cap, hist, ans, vis_ids, d_rounds) in enumerate(dataloader): + + start_time = time() + vis = vis.to(device, non_blocking=True) + is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa'] + + # First get the visual features depending on the media type + with torch.cuda.amp.autocast(enabled=config.fp16): + # construct the global input tensor --> use place holder for vis features + cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None) + hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None) + + if config.use_moes: + if config.use_sep_spatial_temp_experts: + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid) + else: + vis_embed, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + + if config.use_sep_spatial_temp_experts: + moe_outputs = model.moe_forward( + vis_embed_spatial, vis_spatial_mask, + vis_embed_temporal, vis_temporal_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds']) + temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + else: + moe_outputs = model.moe_forward_no_sep_spatial_temporal( + vis_embed, vis_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + vis_embeds = model.moe_to_llm(moe_outputs['vis_embeds']) + + cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds']) + hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds']) + else: + cap_embeds = model.llm_to_moe(model.text_embedding(cap_ids)) + hist_embeds = model.llm_to_moe(model.text_embedding(hist_ids)) + vis_embeds, vis_mask = model.encode_vis_with_seq_spa_temp_att(vis, device, is_vid=is_vid) + + if config.llm_family in ['llama', 'mistral']: + bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id + bos_embeds = model.text_embedding(bos) + bos_mask = cap_mask[:, :1] + + inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + if is_vid: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1) + + else: + inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + if config.use_moes: + if not config.drop_vis_features: + if config.use_sep_spatial_temp_experts: + if is_vid: + inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([vis_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_mask, attention_mask], dim=1) + + decoded_ids = model.llm.generate( + inputs_embeds=inputs_embeds, + do_sample=False, + top_p=config.top_p, + temperature=config.temperature, + num_beams=config.beam_depth, + length_penalty=config.length_penalty, + max_length=config.max_generation_length, + pad_token_id=model.tokenizer.pad_token_id, + eos_token_id=model.tokenizer.eos_token_id, + # use_cache=True + ) + + response_batch = [model.tokenizer.decode(decoded_id, skip_special_tokens=True) for decoded_id in decoded_ids] + + for vis_id, d_round, response in zip(vis_ids.tolist(), d_rounds.tolist(), response_batch): + responses[str(vis_id) + '_' + str(d_round)] = response + + time_elapsed = time() - start_time + print('Generating resonse {} / {} -- eta = {} '.format(counter + 1, len(dataloader), str(datetime.timedelta(seconds=time_elapsed * (len(dataloader)-counter))) + )) + + # # Create a file with all responses + # with open(config['anno_avsd_test_dstc_{}'.format(config['dstc'])], 'r') as f: + # test_data = json.load(f) + # test_dialogs = deepcopy(test_data['dialogs']) + # # Filter the predicted dialogs + # test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs)) + + # for i, dialog in enumerate(test_dialogs): + # vid_id = dialog['image_id'] + # gen_response = responses[vid_id] + # round_num_to_answer = len(dialog['dialog'])-1 + # assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__' + # dialog['dialog'][round_num_to_answer]['answer'] = gen_response + # test_dialogs[i] = dialog + + # Log the file + file_name = '{}_results_dstc{}_beam_depth_{}_lenPen_{}'.format(config['llm_name'].replace('/', '-'), config['dstc'], config['beam_depth'], config['length_penalty']) + if gen_subset_num is not None: + file_name += f'-part_{gen_subset_num}' + file_name = f'{tag}_' + file_name + output_path = os.path.join(config['output_dir_visdial'], file_name + '.json') + with open(output_path, 'w') as f: + json.dump(responses, f, indent=4) + logger.info('Results logged to {}'.format(output_path)) + # Switch back to training mode + model.train() + +def generate_nextqa(model, dataloader, tag, config, gen_subset_num=None): + + model.eval() + responses = {} + # tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec + device = next(model.parameters()).device # Assumes all model parameters are on the same device + # Generate the repsonse for each round + logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader))) + with torch.no_grad(): + # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader): + for counter, (vis, cap, hist, _, vid_ids, qid) in enumerate(dataloader): + + start_time = time() + vis = vis.to(device, non_blocking=True) + is_vid = config.media_test in ['webvid', 'champagne', 'avsd', 'nextqa'] + + vid_id = vid_ids[0] + qid = qid[0] + if vid_id not in responses: + responses[vid_id] = {} + + # First get the visual features depending on the media type + with torch.cuda.amp.autocast(enabled=config.fp16): + vis_embed_spatial, vis_spatial_mask, vis_embed_temporal, vis_temporal_mask = model.encode_vis(vis, device, is_vid=is_vid) + + # construct the global input tensor --> use place holder for vis features + cap_ids, cap_mask = model.tokenize_text(cap, device, max_len=None) + hist_ids, hist_mask = model.tokenize_text(hist, device, max_len=None) + + moe_outputs = model.moe_forward( + vis_embed_spatial, vis_spatial_mask, + vis_embed_temporal, vis_temporal_mask, + cap_ids, cap_mask, + hist_ids, hist_mask, + is_vid, device + ) + spatial_embeds = model.moe_to_llm(moe_outputs['spatial_embeds']) + temporal_embeds = model.moe_to_llm(moe_outputs['temporal_embeds']) if is_vid else None + cap_embeds = model.moe_to_llm(moe_outputs['cap_embeds']) + hist_embeds = model.moe_to_llm(moe_outputs['hist_embeds']) + + if config.llm_family in ['llama', 'mistral']: + bos = torch.ones_like(cap_ids[:, :1]) * model.tokenizer.bos_token_id + bos_embeds = model.text_embedding(bos) + bos_mask = cap_mask[:, :1] + + inputs_embeds, attention_mask = model.pad_to_right_dec_only_gen_mode(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + if is_vid: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([bos_embeds, spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_mask, vis_spatial_mask, attention_mask], dim=1) + + else: + inputs_embeds, attention_mask = model.pad_to_right_enc_dec(cap_embeds, cap_mask, hist_embeds, hist_mask, device) + + if is_vid: + inputs_embeds = torch.cat([spatial_embeds, temporal_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, vis_temporal_mask, attention_mask], dim=1) + else: + inputs_embeds = torch.cat([spatial_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([vis_spatial_mask, attention_mask], dim=1) + + decoded_ids = model.llm.generate( + inputs_embeds=inputs_embeds, + do_sample=False, + top_p=config.top_p, + temperature=config.temperature, + num_beams=config.beam_depth, + length_penalty=config.length_penalty, + max_length=config.max_generation_length, + pad_token_id=model.tokenizer.pad_token_id, + eos_token_id=model.tokenizer.eos_token_id, + # use_cache=True + ) + + response = model.tokenizer.decode(decoded_ids[0], skip_special_tokens=True) + responses[vid_id][qid] = response + + # for vis_id, response in zip(vis_ids, response_batch): + # responses[vis_id] = response + + time_elapsed = int(time() - start_time) + print('Generating resonse {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed)) + + # Create a file with all responses + file_name = 'results_nextqa_beam_depth_{}'.format(config['beam_depth']) + if gen_subset_num is not None: + file_name += f'-part_{gen_subset_num}' + file_name = f'{tag}_' + file_name + output_path = os.path.join(config['output_dir_nextqa'], file_name + '.json') + with open(output_path, 'w') as f: + json.dump(responses, f, indent=4) + print('Results logged to {}'.format(output_path)) + print(os.getcwd()) + # Switch back to training mode + model.train() + + +def generate_enc_dec(model, dataloader, tag, config, gen_subset_num=None): + + model.eval() + responses = {} + tokenizer_enc_dec = dataloader.dataset.tokenizer_enc_dec + device = next(model.parameters()).device # Assumes all model parameters are on the same device + # Generate the repsonse for each round + logger.info('[INFO] Generating responses for {} samples'.format(len(dataloader))) + with torch.no_grad(): + # for counter, (vis, cap_ids, hist_ids, ques_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader): + for counter, (vis, cap_ids, hist_ids, _, enc_dec_input_ids, _, vid_id) in enumerate(dataloader): + + start_time = time() + vis = vis.to(device, non_blocking=True) + + for k in cap_ids: + if isinstance(cap_ids[k], torch.Tensor): + cap_ids[k] = cap_ids[k].to(device) + + for k in hist_ids: + if isinstance(hist_ids[k], torch.Tensor): + hist_ids[k] = hist_ids[k].to(device) + + # for k in ques_ids: + # if isinstance(ques_ids[k], torch.Tensor): + # ques_ids[k] = ques_ids[k].to(device) + + for k in enc_dec_input_ids: + if isinstance(enc_dec_input_ids[k], torch.Tensor): + enc_dec_input_ids[k] = enc_dec_input_ids[k].to(device) + + # response = beam_search_generation( + # model, vis, cap_ids, hist_ids, ques_ids, enc_dec_input_ids, tokenizer_enc_dec, config + # ) + + response = beam_search_generation( + model, vis, cap_ids, hist_ids, enc_dec_input_ids, tokenizer_enc_dec, config + ) + + # Decode the response + response = tokenizer_enc_dec.decode(response) + responses[vid_id[0]] = response + # all_graphs[vid] = graphs + time_elapsed = int(time() - start_time) + print('Generating resonse {} / {} -- took {}s'.format(counter + 1, len(dataloader), time_elapsed)) + + # Create a file with all responses + with open(config['anno_avsd_test_{}'.format(config['dstc'])], 'r') as f: + test_data = json.load(f) + test_dialogs = deepcopy(test_data['dialogs']) + # Filter the predicted dialogs + test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs)) + + for i, dialog in enumerate(test_dialogs): + vid_id = dialog['image_id'] + gen_response = responses[vid_id] + round_num_to_answer = len(dialog['dialog'])-1 + assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__' + dialog['dialog'][round_num_to_answer]['answer'] = gen_response + test_dialogs[i] = dialog + + # Log the file + file_name = 'results_dstc{}_beam_depth_{}'.format(config['dstc'], config['beam_depth']) + if gen_subset_num is not None: + file_name += f'-part_{gen_subset_num}' + file_name = f'{tag}_' + file_name + output_path = os.path.join(config['output_dir_avsd_{}'.format(config['dstc'])], file_name + '.json') + with open(output_path, 'w') as f: + json.dump({'dialogs': test_dialogs}, f, indent=4) + logger.info('Results logged to {}'.format(output_path)) + # Switch back to training mode + model.train() + + +def beam_search_generation_decoder_only(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config): + + # gen_ans = [bos_token] + hyplist = [([], 0.0, [])] + best_state = None + comp_hyplist = [] + + # drop_caption = self.config['dstc'] == 10 + # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption) + + encoder_outputs = None + + for i in range(config['max_generation_length']): + new_hyplist = [] + argmin = 0 + for out, lp, st in hyplist: + decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0) + + # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd') + output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd') + + if encoder_outputs is None: + encoder_outputs = output.encoder_outputs + + logits = output['logits'][:,-1,:].squeeze() # get the logits of the last token + logp = F.log_softmax(logits, dim=0) + lp_vec = logp.cpu().data.numpy() + lp + if i >= config['min_generation_length']: + new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1) + comp_hyplist.append((out, new_lp)) + if best_state is None or best_state < new_lp: + best_state = new_lp + count = 1 + for o in np.argsort(lp_vec)[::-1]: # reverse the order + if o in [eos_token, unk_token]: + continue + new_lp = lp_vec[o] + if len(new_hyplist) == config['beam_depth']: + if new_hyplist[argmin][1] < new_lp: + new_st = deepcopy(st) + new_st.append(int(o)) + new_hyplist[argmin] = (out + [o], new_lp, new_st) + argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0] + else: + break + else: + new_st = deepcopy(st) + new_st.append(int(o)) + new_hyplist.append((out + [o], new_lp, new_st)) + if len(new_hyplist) == config['beam_depth']: + argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0] + count += 1 + hyplist = new_hyplist + + if len(comp_hyplist) > 0: + maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1] + res = maxhyps[0][0] + if res[0] == bos_token: + res = res[1:] + if res[-1] == eos_token: + res = res[:-1] + return res + else: + return [] + + +# def beam_search_generation(model, vis, caption, history, ques, enc_dec_input, tokenizer_enc_dec, config): +def beam_search_generation(model, vis, caption, history, enc_dec_input, tokenizer_enc_dec, config): + + if config['enc_dec_family'] == 'flan_t5': + bos_token = tokenizer_enc_dec.pad_token_id + eos_token = tokenizer_enc_dec.eos_token_id + else: + bos_token = tokenizer_enc_dec.bos_token_id + eos_token = tokenizer_enc_dec.eos_token_id + + unk_token = tokenizer_enc_dec.unk_token_id + + # gen_ans = [bos_token] + hyplist = [([], 0.0, [bos_token])] + best_state = None + comp_hyplist = [] + + # drop_caption = self.config['dstc'] == 10 + # instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption) + + encoder_outputs = None + + for i in range(config['max_generation_length']): + new_hyplist = [] + argmin = 0 + for out, lp, st in hyplist: + decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0) + + # output = model.generate(vis, caption, history, ques, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd') + output = model.generate(vis, caption, history, decoder_input_ids, enc_dec_input, encoder_outputs, 'avsd') + + if encoder_outputs is None: + encoder_outputs = output.encoder_outputs + + logits = output['logits'][:,-1,:].squeeze() # get the logits of the last token + logp = F.log_softmax(logits, dim=0) + lp_vec = logp.cpu().data.numpy() + lp + if i >= config['min_generation_length']: + new_lp = lp_vec[eos_token] + config['length_penalty'] * (len(out) + 1) + comp_hyplist.append((out, new_lp)) + if best_state is None or best_state < new_lp: + best_state = new_lp + count = 1 + for o in np.argsort(lp_vec)[::-1]: # reverse the order + if o in [eos_token, unk_token]: + continue + new_lp = lp_vec[o] + if len(new_hyplist) == config['beam_depth']: + if new_hyplist[argmin][1] < new_lp: + new_st = deepcopy(st) + new_st.append(int(o)) + new_hyplist[argmin] = (out + [o], new_lp, new_st) + argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0] + else: + break + else: + new_st = deepcopy(st) + new_st.append(int(o)) + new_hyplist.append((out + [o], new_lp, new_st)) + if len(new_hyplist) == config['beam_depth']: + argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0] + count += 1 + hyplist = new_hyplist + + if len(comp_hyplist) > 0: + maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1] + res = maxhyps[0][0] + if res[0] == bos_token: + res = res[1:] + if res[-1] == eos_token: + res = res[:-1] + return res + else: + return [] \ No newline at end of file diff --git a/tokenizers/flan_t5/added_tokens.json b/tokenizers/flan_t5/added_tokens.json new file mode 100644 index 0000000..af404f1 --- /dev/null +++ b/tokenizers/flan_t5/added_tokens.json @@ -0,0 +1,109 @@ +{ + "": 32105, + "": 32099, + "": 32089, + "": 32088, + "": 32087, + "": 32086, + "": 32085, + "": 32084, + "": 32083, + "": 32082, + "": 32081, + "": 32080, + "": 32098, + "": 32079, + "": 32078, + "": 32077, + "": 32076, + "": 32075, + "": 32074, + "": 32073, + "": 32072, + "": 32071, + "": 32070, + "": 32097, + "": 32069, + "": 32068, + "": 32067, + "": 32066, + "": 32065, + "": 32064, + "": 32063, + "": 32062, + "": 32061, + "": 32060, + "": 32096, + "": 32059, + "": 32058, + "": 32057, + "": 32056, + "": 32055, + "": 32054, + "": 32053, + "": 32052, + "": 32051, + "": 32050, + "": 32095, + "": 32049, + "": 32048, + "": 32047, + "": 32046, + "": 32045, + "": 32044, + "": 32043, + "": 32042, + "": 32041, + "": 32040, + "": 32094, + "": 32039, + "": 32038, + "": 32037, + "": 32036, + "": 32035, + "": 32034, + "": 32033, + "": 32032, + "": 32031, + "": 32030, + "": 32093, + "": 32029, + "": 32028, + "": 32027, + "": 32026, + "": 32025, + "": 32024, + "": 32023, + "": 32022, + "": 32021, + "": 32020, + "": 32092, + "": 32019, + "": 32018, + "": 32017, + "": 32016, + "": 32015, + "": 32014, + "": 32013, + "": 32012, + "": 32011, + "": 32010, + "": 32091, + "": 32009, + "": 32008, + "": 32007, + "": 32006, + "": 32005, + "": 32004, + "": 32003, + "": 32002, + "": 32001, + "": 32000, + "": 32090, + "": 32106, + "": 32101, + "": 32100, + "": 32103, + "": 32104, + "": 32102 +} diff --git a/tokenizers/flan_t5/special_tokens_map.json b/tokenizers/flan_t5/special_tokens_map.json new file mode 100644 index 0000000..2e4806a --- /dev/null +++ b/tokenizers/flan_t5/special_tokens_map.json @@ -0,0 +1,74 @@ +{ + "additional_special_tokens": [ + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tokenizers/flan_t5/spiece.model b/tokenizers/flan_t5/spiece.model new file mode 100644 index 0000000..4e28ff6 Binary files /dev/null and b/tokenizers/flan_t5/spiece.model differ diff --git a/tokenizers/flan_t5/tokenizer_config.json b/tokenizers/flan_t5/tokenizer_config.json new file mode 100644 index 0000000..ef46cff --- /dev/null +++ b/tokenizers/flan_t5/tokenizer_config.json @@ -0,0 +1,902 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32007": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32008": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32009": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32010": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32011": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32012": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32013": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32014": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32015": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32016": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32017": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32018": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32019": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32020": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32021": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32022": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32023": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32024": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32025": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32026": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32027": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32028": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32029": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32030": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32031": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32032": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32033": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32034": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32035": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32036": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32037": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32038": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32039": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32040": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32041": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32042": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32043": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32044": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32045": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32046": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32047": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32048": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32049": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32050": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32051": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32052": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32053": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32054": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32055": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32056": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32057": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32058": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32059": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32060": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32061": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32062": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32063": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32064": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32065": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32066": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32067": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32068": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32069": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32070": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32071": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32072": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32073": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32074": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32075": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32076": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32077": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32078": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32079": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32080": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32081": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32082": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32083": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32084": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32085": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32086": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32087": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32088": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32089": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32090": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32091": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32092": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32093": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32094": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32095": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32096": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32097": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32098": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32099": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32100": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32101": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32102": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32103": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32104": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32105": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32106": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "" + ], + "bos_token": "", + "clean_up_tokenization_spaces": true, + "eos_token": "", + "extra_ids": 100, + "legacy": true, + "mask_token": "", + "model_max_length": 512, + "pad_token": "", + "sp_model_kwargs": {}, + "tokenizer_class": "T5Tokenizer", + "unk_token": "" +} diff --git a/tokenizers/llama/added_tokens.json b/tokenizers/llama/added_tokens.json new file mode 100644 index 0000000..4a9d3f4 --- /dev/null +++ b/tokenizers/llama/added_tokens.json @@ -0,0 +1,9 @@ +{ + "": 32005, + "": 32006, + "": 32001, + "": 32000, + "": 32003, + "": 32004, + "": 32002 +} diff --git a/tokenizers/llama/special_tokens_map.json b/tokenizers/llama/special_tokens_map.json new file mode 100644 index 0000000..2e4806a --- /dev/null +++ b/tokenizers/llama/special_tokens_map.json @@ -0,0 +1,74 @@ +{ + "additional_special_tokens": [ + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tokenizers/llama/tokenizer.model b/tokenizers/llama/tokenizer.model new file mode 100644 index 0000000..22bccbc Binary files /dev/null and b/tokenizers/llama/tokenizer.model differ diff --git a/tokenizers/llama/tokenizer_config.json b/tokenizers/llama/tokenizer_config.json new file mode 100644 index 0000000..5410f5c --- /dev/null +++ b/tokenizers/llama/tokenizer_config.json @@ -0,0 +1,107 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "" + ], + "bos_token": "", + "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "legacy": false, + "mask_token": "", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "padding_side": "right", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +} diff --git a/tokenizers/mistral/added_tokens.json b/tokenizers/mistral/added_tokens.json new file mode 100644 index 0000000..4a9d3f4 --- /dev/null +++ b/tokenizers/mistral/added_tokens.json @@ -0,0 +1,9 @@ +{ + "": 32005, + "": 32006, + "": 32001, + "": 32000, + "": 32003, + "": 32004, + "": 32002 +} diff --git a/tokenizers/mistral/special_tokens_map.json b/tokenizers/mistral/special_tokens_map.json new file mode 100644 index 0000000..2e4806a --- /dev/null +++ b/tokenizers/mistral/special_tokens_map.json @@ -0,0 +1,74 @@ +{ + "additional_special_tokens": [ + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tokenizers/mistral/tokenizer.model b/tokenizers/mistral/tokenizer.model new file mode 100644 index 0000000..85c0803 Binary files /dev/null and b/tokenizers/mistral/tokenizer.model differ diff --git a/tokenizers/mistral/tokenizer_config.json b/tokenizers/mistral/tokenizer_config.json new file mode 100644 index 0000000..2b3ca46 --- /dev/null +++ b/tokenizers/mistral/tokenizer_config.json @@ -0,0 +1,106 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "" + ], + "bos_token": "", + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "legacy": true, + "mask_token": "", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +} diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/basic.py b/utils/basic.py new file mode 100644 index 0000000..b482c2a --- /dev/null +++ b/utils/basic.py @@ -0,0 +1,309 @@ +import numpy as np +import io +import os +import json +import logging +import random +import time +from collections import defaultdict, deque +import datetime +from pathlib import Path +from typing import List, Union +import itertools + +import torch +import torch.distributed as dist +from .dist import is_dist_avail_and_initialized + + +logger = logging.getLogger(__name__) + + +def average_dicts(dicts): + # media = list(dicts.keys()) + # keys = [list(d.keys()) for d in dicts.values] + # keys = list(itertools.chain.from_iterable(keys)) + # keys = list(set(keys)) + res = {} + counter = {} + for medium, medium_dict in dicts.items(): + for loss_key, loss_value in medium_dict.items(): + if loss_key not in res: + res[loss_key] = loss_value + counter[loss_key] = 1 + else: + res[loss_key] += loss_value + counter[loss_key] += 1 + for k in res: + res[k] = res[k] / counter[k] + return res + + + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], + dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + if meter.count == 0: # skip empty meter + loss_str.append( + "{}: {}".format(name, "No data") + ) + else: + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + if meter.count == 0: + loss_str.append( + "{}: {}".format(name, "No data") + ) + else: + loss_str.append( + "{}: {:.4f}".format(name, meter.global_avg) + ) + return self.delimiter.join(loss_str) + + def get_global_avg_dict(self, prefix=""): + """include a separator (e.g., `/`, or "_") at the end of `prefix`""" + d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()} + return d + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, log_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}\n', + '{meters}\n', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % log_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + res_mem=torch.cuda.max_memory_reserved() / MB, + )) + else: + logger.info(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def compute_acc(logits, label, reduction='mean'): + ret = (torch.argmax(logits, dim=1) == label).float() + if reduction == 'none': + return ret.detach() + elif reduction == 'mean': + return ret.mean().item() + + +def compute_n_params(model, return_str=True): + tot = 0 + for p in model.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return '{:.1f}M'.format(tot / 1e6) + else: + return '{:.1f}K'.format(tot / 1e3) + else: + return tot + + +def setup_seed(seed): + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +def remove_files_if_exist(file_paths): + for fp in file_paths: + if os.path.isfile(fp): + os.remove(fp) + + +def save_json(data, filename, save_pretty=False, sort_keys=False): + with open(filename, "w") as f: + if save_pretty: + f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) + else: + json.dump(data, f) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +def flat_list_of_lists(l): + """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" + return [item for sublist in l for item in sublist] + + +def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]): + """ + Args: + root: path to the directory to start search files + suffix: any str as suffix, or can match multiple such strings + when input is List[str]. + Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`] + Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`. + """ + if isinstance(suffix, str): + suffix = [suffix, ] + filepaths = flat_list_of_lists( + [list(Path(root).rglob(f"*{e}")) for e in suffix]) + return filepaths + + +def match_key_and_shape(state_dict1, state_dict2): + keys1 = set(state_dict1.keys()) + keys2 = set(state_dict2.keys()) + print(f"keys1 - keys2: {keys1 - keys2}") + print(f"keys2 - keys1: {keys2 - keys1}") + + mismatch = 0 + for k in list(keys1): + if state_dict1[k].shape != state_dict2[k].shape: + print( + f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}") + mismatch += 1 + print(f"mismatch {mismatch}") + + +def merge_dicts(list_dicts): + merged_dict = list_dicts[0].copy() + for i in range(1, len(list_dicts)): + merged_dict.update(list_dicts[i]) + return merged_dict diff --git a/utils/dist.py b/utils/dist.py new file mode 100644 index 0000000..f8d92e2 --- /dev/null +++ b/utils/dist.py @@ -0,0 +1,25 @@ +import torch.distributed as dist + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 \ No newline at end of file diff --git a/utils/easydict.py b/utils/easydict.py new file mode 100644 index 0000000..241aca4 --- /dev/null +++ b/utils/easydict.py @@ -0,0 +1,149 @@ +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> map(attrgetter('x'), d.bar) + [1, 3] + >>> map(attrgetter('y'), d.bar) + [2, 4] + >>> d = EasyDict() + >>> d.keys() + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> o.items() + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + if hasattr(self, k): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + diff --git a/utils/init.py b/utils/init.py new file mode 100644 index 0000000..cad40fa --- /dev/null +++ b/utils/init.py @@ -0,0 +1,154 @@ +import os +import torch +import random +import pyhocon +import datetime +import json +import subprocess +import itertools +import glob +import glog as log +import sys +import re +from os import path as osp +import numpy as np + + +# def load_runner(config, tokenizer, vocab_size): +# if config['task'] == 'avsd': +# return AVSDRunner(config, tokenizer, vocab_size) +# if config['task'] == 'simmc': +# return SIMMCRunner(config, tokenizer, vocab_size) +# elif config['task'] == 'nextqa': +# return NEXTQARunner(config, tokenizer, vocab_size) +# else: +# raise ValueError + + + + +def set_random_seed(random_seed): + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + random.seed(random_seed) + np.random.seed(random_seed) + + +def copy_file_to_log(log_dir): + dirs_to_cp = ['.', 'config', 'datasets', 'runners', 'models'] + files_to_cp = ['*.py', '*.json', '*.sh', '*.conf'] + for dir_name in dirs_to_cp: + dir_name = osp.join(log_dir, 'code', dir_name) + if not osp.exists(dir_name): + os.makedirs(dir_name) + for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp): + filename = osp.join(dir_name, file_name) + if len(glob.glob(filename)) > 0: + os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}') + log.info(f'Files copied to {osp.join(log_dir, "code")}') + + +def set_log_file(fname, file_only=False): + # if fname already exists, find all log file under log dir, + # and name the current log file with a new number + if osp.exists(fname): + prefix, suffix = osp.splitext(fname) + log_files = glob.glob(prefix + '*' + suffix) + count = 0 + for log_file in log_files: + num = re.search(r'(\d+)', log_file) + if num is not None: + num = int(num.group(0)) + count = max(num, count) + fname = fname.replace(suffix, str(count + 1) + suffix) + # set log file + # simple tricks for duplicating logging destination in the logging module such as: + # logging.getLogger().addHandler(logging.FileHandler(filename)) + # does NOT work well here, because python Traceback message (not via logging module) is not sent to the file, + # the following solution (copied from : https://stackoverflow.com/questions/616645) is a little bit + # complicated but simulates exactly the "tee" command in linux shell, and it redirects everything + if file_only: + # we only output messages to file, and stdout/stderr receives nothing. + # this feature is designed for executing the script via ssh: + # since ssh has a windowing kind of flow control, i.e., if the controller does not read data from a + # ssh channel and its buffer fills up, the execution machine will not be able to write anything into the + # channel and the process will be set to sleeping (S) status until someone reads all data from the channel. + # this is not desired since we do not want to read stdout/stderr from the controller machine. + # so, here we use a simple solution: disable output to stdout/stderr and only output messages to log file. + log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1) + else: + # we output messages to both file and stdout/stderr + tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE) + os.dup2(tee.stdin.fileno(), sys.stdout.fileno()) + os.dup2(tee.stdin.fileno(), sys.stderr.fileno()) + + +def set_training_steps(config, num_samples): + if config['parallel'] and config['dp_type'] == 'dp': + config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size'])) + else: + config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus']))) + if 'train_steps' not in config: + config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs'] + if 'warmup_steps' not in config: + config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio']) + return config + + +def initialize_from_env(model, mode, stage, eval_dir, tag=''): + + if mode in ['train', 'debug']: + path_config = f"config/{model}_{stage}.conf" + config = pyhocon.ConfigFactory.parse_file(path_config)[stage] + else: + path_config = os.path.join(eval_dir, f'{model}_{stage}.conf') + config = pyhocon.ConfigFactory.parse_file(path_config)[stage] + config['log_dir'] = eval_dir + + if "CUDA_VISIBLE_DEVICES" in os.environ: + config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(',')) + # multi-gpu setting + if config['num_gpus'] > 1: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ["MASTER_PORT"] = str(config['master_port']) + else: + config['num_gpus'] = 1 + + model += '-' + config.llm_name.replace('/', '_') + + if mode == 'debug': + model += '_debug' + + if tag: + model += '-' + tag + if mode != 'generate': + config["log_dir"] = os.path.join(config["log_dir"], model) + if not os.path.exists(config["log_dir"]): + os.makedirs(config["log_dir"]) + # copy the config file + os.system(f'cp {path_config} {config["log_dir"]}') + + config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S') + + config['expert_config'] = config['bert_config_{}'.format(config['expert_size'])] + config['expert_config_json'] = json.load(open(config['expert_config'], 'r')) + + config['beit_config_json'] = json.load(open(config['beit_config'], 'r')) + + + config['model'] = model + config['stage'] = stage + config['loss_dict'] = {k:v for k,v in zip(config['loss_names'], config['loss_weights'])} + + return config + + +def set_training_steps(config, num_samples, batch_sizes): + config['num_iter_per_epoch'] = sum([int(np.ceil(num_sample / (bs * config['accum_grad_every'] * config['num_gpus']))) for num_sample, bs in zip(num_samples, batch_sizes)]) + if 'num_training_steps' not in config: + config['num_training_steps'] = config['num_iter_per_epoch'] * config['epochs'] + if 'num_warmup_steps' not in config: + config['num_warmup_steps'] = int(config['num_iter_per_epoch'] * config.get('warmup_epochs', 1.0)) + + # config['num_warmup_steps'] = int(config['num_training_steps'] * config['warmup_ratio']) + return config \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..544f559 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,286 @@ +# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import logging +import os +import sys +import time +import wandb + +from .dist import get_rank, is_main_process +from termcolor import colored + + +def log_dict_to_wandb(log_dict, step, prefix=""): + """include a separator `/` at the end of `prefix`""" + if not is_main_process(): + return + + log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()} + wandb.log(log_dict, step) + + +def setup_wandb(config): + if not (config.wandb_enabled and is_main_process()): + return + + run = wandb.init( + config=config, + project=config.wandb_project, + # entity=config.wandb.entity, + mode=config.wandb_mode, + # name=os.path.basename(config.output_dir), + reinit=True + ) + wandb.define_metric('train/webvid/step') + wandb.define_metric('train/webvid/*', 'train/webvid/step') + + wandb.define_metric('train/cc3m/step') + wandb.define_metric('train/cc3m/*', 'train/cc3m/step') + + wandb.define_metric('train/other/step') + wandb.define_metric('train/other/*', 'train/other/step') + + wandb.define_metric('val/msrvtt/step') + wandb.define_metric('val/msrvtt/*', 'val/msrvtt/step') + + wandb.define_metric('train/champagne/step') + wandb.define_metric('train/champagne/*', 'train/champagne/step') + + wandb.define_metric('train/visdial/step') + wandb.define_metric('train/visdial/*', 'train/visdial/step') + + wandb.define_metric('train/avsd/step') + wandb.define_metric('train/avsd/*', 'train/avsd/step') + + wandb.define_metric('train/nextqa/step') + wandb.define_metric('train/nextqa/*', 'train/nextqa/step') + + return run + + +def setup_output_folder(save_dir: str, folder_only: bool = False): + """Sets up and returns the output file where the logs will be placed + based on the configuration passed. Usually "save_dir/logs/log_.txt". + If env.log_dir is passed, logs will be directly saved in this folder. + Args: + folder_only (bool, optional): If folder should be returned and not the file. + Defaults to False. + Returns: + str: folder or file path depending on folder_only flag + """ + log_filename = "train_" + log_filename += time.strftime("%Y_%m_%dT%H_%M_%S") + log_filename += ".log" + + log_folder = os.path.join(save_dir, "logs") + + if not os.path.exists(log_folder): + os.path.mkdirs(log_folder) + + if folder_only: + return log_folder + + log_filename = os.path.join(log_folder, log_filename) + + return log_filename + + +def setup_logger( + output: str = None, + color: bool = True, + name: str = "mmf", + disable: bool = False, + clear_handlers=True, + *args, + **kwargs, +): + """ + Initialize the MMF logger and set its verbosity level to "INFO". + Outside libraries shouldn't call this in case they have set there + own logging handlers and setup. If they do, and don't want to + clear handlers, pass clear_handlers options. + The initial version of this function was taken from D2 and adapted + for MMF. + Args: + output (str): a file name or a directory to save log. + If ends with ".txt" or ".log", assumed to be a file name. + Default: Saved to file + color (bool): If false, won't log colored logs. Default: true + name (str): the root module name of this logger. Defaults to "mmf". + disable: do not use + clear_handlers (bool): If false, won't clear existing handlers. + Returns: + logging.Logger: a logger + """ + if disable: + return None + logger = logging.getLogger(name) + logger.propagate = False + + logging.captureWarnings(True) + warnings_logger = logging.getLogger("py.warnings") + + plain_formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(name)s : %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + + distributed_rank = get_rank() + handlers = [] + + logging_level = logging.INFO + # logging_level = logging.DEBUG + + if distributed_rank == 0: + logger.setLevel(logging_level) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging_level) + if color: + formatter = ColorfulFormatter( + colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + warnings_logger.addHandler(ch) + handlers.append(ch) + + # file logging: all workers + if output is None: + output = setup_output_folder() + + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "train.log") + if distributed_rank > 0: + filename = filename + f".rank{distributed_rank}" + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging_level) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + warnings_logger.addHandler(fh) + handlers.append(fh) + + # Slurm/FB output, only log the main process + # save_dir = get_mmf_env(key="save_dir") + if "train.log" not in filename and distributed_rank == 0: + filename = os.path.join(output, "train.log") + sh = logging.StreamHandler(_cached_log_stream(filename)) + sh.setLevel(logging_level) + sh.setFormatter(plain_formatter) + logger.addHandler(sh) + warnings_logger.addHandler(sh) + handlers.append(sh) + + logger.info(f"Logging to: {filename}") + + # Remove existing handlers to add MMF specific handlers + if clear_handlers: + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + # Now, add our handlers. + logging.basicConfig(level=logging_level, handlers=handlers) + + return logger + + +def setup_very_basic_config(color=True): + plain_formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(name)s : %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.INFO) + if color: + formatter = ColorfulFormatter( + colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + # Setup a minimal configuration for logging in case something tries to + # log a message even before logging is setup by MMF. + logging.basicConfig(level=logging.INFO, handlers=[ch]) + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return open(filename, "a") + + +# ColorfulFormatter is adopted from Detectron2 and adapted for MMF +class ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super().formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +class TensorboardLogger: + def __init__(self, log_folder="./logs", iteration=0): + # This would handle warning of missing tensorboard + from torch.utils.tensorboard import SummaryWriter + + self.summary_writer = None + self._is_master = is_main_process() + # self.timer = Timer() + self.log_folder = log_folder + + if self._is_master: + # current_time = self.timer.get_time_hhmmss(None, format=self.time_format) + current_time = time.strftime("%Y-%m-%dT%H:%M:%S") + # self.timer.get_time_hhmmss(None, format=self.time_format) + tensorboard_folder = os.path.join( + self.log_folder, f"tensorboard_{current_time}" + ) + self.summary_writer = SummaryWriter(tensorboard_folder) + + def __del__(self): + if getattr(self, "summary_writer", None) is not None: + self.summary_writer.close() + + def _should_log_tensorboard(self): + if self.summary_writer is None or not self._is_master: + return False + else: + return True + + def add_scalar(self, key, value, iteration): + if not self._should_log_tensorboard(): + return + + self.summary_writer.add_scalar(key, value, iteration) + + def add_scalars(self, scalar_dict, iteration): + if not self._should_log_tensorboard(): + return + + for key, val in scalar_dict.items(): + self.summary_writer.add_scalar(key, val, iteration) + + def add_histogram_for_model(self, model, iteration): + if not self._should_log_tensorboard(): + return + + for name, param in model.named_parameters(): + np_param = param.clone().cpu().data.numpy() + self.summary_writer.add_histogram(name, np_param, iteration) diff --git a/utils/metrcis.py b/utils/metrcis.py new file mode 100644 index 0000000..8f4c132 --- /dev/null +++ b/utils/metrcis.py @@ -0,0 +1,174 @@ +""" +A Metric observes output of certain model, for example, in form of logits or +scores, and accumulates a particular metric with reference to some provided +targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean +Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). + +Each ``Metric`` must atleast implement three methods: + - ``observe``, update accumulated metric with currently observed outputs + and targets. + - ``retrieve`` to return the accumulated metric., an optionally reset + internally accumulated metric (this is commonly done between two epochs + after validation). + - ``reset`` to explicitly reset the internally accumulated metric. + +Caveat, if you wish to implement your own class of Metric, make sure you call +``detach`` on output tensors (like logits), else it will cause memory leaks. +""" +import torch + + +def scores_to_ranks(scores: torch.Tensor): + """Convert model output scores into ranks.""" + batch_size, num_rounds, num_options = scores.size() + scores = scores.view(-1, num_options) + + # sort in descending order - largest score gets highest rank + sorted_ranks, ranked_idx = scores.sort(1, descending=True) + + # i-th position in ranked_idx specifies which score shall take this + # position but we want i-th position to have rank of score at that + # position, do this conversion + ranks = ranked_idx.clone().fill_(0) + for i in range(ranked_idx.size(0)): + for j in range(num_options): + ranks[i][ranked_idx[i][j]] = j + # convert from 0-99 ranks to 1-100 ranks + ranks += 1 + ranks = ranks.view(batch_size, num_rounds, num_options) + return ranks + + +class SparseGTMetrics(object): + """ + A class to accumulate all metrics with sparse ground truth annotations. + These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. + """ + + def __init__(self): + self._rank_list = [] + + def observe( + self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor + ): + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, num_rounds, num_options) + predicted_ranks = scores_to_ranks(predicted_scores) + batch_size, num_rounds, num_options = predicted_ranks.size() + + # collapse batch dimension + predicted_ranks = predicted_ranks.view( + batch_size * num_rounds, num_options + ) + + # shape: (batch_size * num_rounds, ) + target_ranks = target_ranks.view(batch_size * num_rounds).long() + + # shape: (batch_size * num_rounds, ) + predicted_gt_ranks = predicted_ranks[ + torch.arange(batch_size * num_rounds), target_ranks + ] + self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) + + def retrieve(self, reset: bool = True): + num_examples = len(self._rank_list) + if num_examples > 0: + # convert to numpy array for easy calculation. + __rank_list = torch.tensor(self._rank_list).float() + metrics = { + "r@1": torch.mean((__rank_list <= 1).float()).item(), + "r@5": torch.mean((__rank_list <= 5).float()).item(), + "r@10": torch.mean((__rank_list <= 10).float()).item(), + "mean": torch.mean(__rank_list).item(), + "mrr": torch.mean(__rank_list.reciprocal()).item(), + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._rank_list = [] + + +class NDCG(object): + def __init__(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + + def observe( + self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor + ): + """ + Observe model output scores and target ground truth relevance and + accumulate NDCG metric. + + Parameters + ---------- + predicted_scores: torch.Tensor + A tensor of shape (batch_size, num_options), because dense + annotations are available for 1 randomly picked round out of 10. + target_relevance: torch.Tensor + A tensor of shape same as predicted scores, indicating ground truth + relevance of each answer option for a particular round. + """ + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, 1, num_options) + predicted_scores = predicted_scores.unsqueeze(1) + predicted_ranks = scores_to_ranks(predicted_scores) + + # shape: (batch_size, num_options) + predicted_ranks = predicted_ranks.squeeze(1) + batch_size, num_options = predicted_ranks.size() + + k = torch.sum(target_relevance != 0, dim=-1) + + # shape: (batch_size, num_options) + _, rankings = torch.sort(predicted_ranks, dim=-1) + # Sort relevance in descending order so highest relevance gets top rnk. + _, best_rankings = torch.sort( + target_relevance, dim=-1, descending=True + ) + + # shape: (batch_size, ) + batch_ndcg = [] + for batch_index in range(batch_size): + + num_relevant = k[batch_index] + dcg = self._dcg( + rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + best_dcg = self._dcg( + best_rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + batch_ndcg.append(dcg / best_dcg) + + self._ndcg_denominator += batch_size + self._ndcg_numerator += sum(batch_ndcg) + + def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): + sorted_relevance = relevance[rankings].cpu().float() + discounts = torch.log2(torch.arange(len(rankings)).float() + 2) + return torch.sum(sorted_relevance / discounts, dim=-1) + + def retrieve(self, reset: bool = True): + if self._ndcg_denominator > 0: + metrics = { + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 \ No newline at end of file diff --git a/utils/optimizer.py b/utils/optimizer.py new file mode 100644 index 0000000..53937a2 --- /dev/null +++ b/utils/optimizer.py @@ -0,0 +1,35 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2020 Ross Wightman +""" +import re +import torch +from torch import optim as optim +from utils.dist import is_main_process +import glog as logger +# from transformers import create_optimizer +# from transformers import AdamW +# import math + + +def create_optimizer(config, model): + lr_scale = config.get('lr_layer_decay', 1) + weight_decay = config.get('weight_decay', 0.01) + + optim_params = model.get_optimizer_params(weight_decay, lr_scale) + + num_parameters = 0 + for p_group in optim_params: + for p in p_group['params']: + num_parameters += p.data.nelement() + logger.info('number of trainable parameters: {}'.format(num_parameters)) + + lr = config.get('lr', 1e-4) + betas = config.get('opt_betas', [0.9, 0.999]) + + optimizer = torch.optim.AdamW( + optim_params, + lr=float(lr), + betas=betas + ) + + return optimizer diff --git a/utils/scheduler.py b/utils/scheduler.py new file mode 100644 index 0000000..7df4d2a --- /dev/null +++ b/utils/scheduler.py @@ -0,0 +1,240 @@ +""" Scheduler Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch.optim import Optimizer +import math +from torch.optim.lr_scheduler import LambdaLR, _LRScheduler +import math + + +# class LinearWarmupStepLRScheduler: +# def __init__( +# self, +# optimizer, +# max_epoch, +# min_lr, +# init_lr, +# decay_rate=1, +# warmup_start_lr=-1, +# warmup_steps=0, +# **kwargs +# ): +# self.optimizer = optimizer + +# self.max_epoch = max_epoch +# self.min_lr = min_lr + +# self.decay_rate = decay_rate + +# self.init_lr = init_lr +# self.warmup_steps = warmup_steps +# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + +# def step(self, cur_epoch, cur_step): +# if cur_epoch == 0: +# warmup_lr_schedule( +# step=cur_step, +# optimizer=self.optimizer, +# max_step=self.warmup_steps, +# init_lr=self.warmup_start_lr, +# max_lr=self.init_lr, +# ) +# else: +# step_lr_schedule( +# epoch=cur_epoch, +# optimizer=self.optimizer, +# init_lr=self.init_lr, +# min_lr=self.min_lr, +# decay_rate=self.decay_rate, +# ) + + +# class LinearWarmupCosineLRScheduler: +# def __init__( +# self, +# optimizer, +# max_epoch, +# min_lr, +# init_lr, +# warmup_steps=0, +# warmup_start_lr=-1, +# **kwargs +# ): +# self.optimizer = optimizer + +# self.max_epoch = max_epoch +# self.min_lr = min_lr + +# self.init_lr = init_lr +# self.warmup_steps = warmup_steps +# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + +# def step(self, cur_epoch, cur_step): +# # assuming the warmup iters less than one epoch +# if cur_epoch == 0: +# warmup_lr_schedule( +# step=cur_step, +# optimizer=self.optimizer, +# max_step=self.warmup_steps, +# init_lr=self.warmup_start_lr, +# max_lr=self.init_lr, +# ) +# else: +# cosine_lr_schedule( +# epoch=cur_epoch, +# optimizer=self.optimizer, +# max_epoch=self.max_epoch, +# init_lr=self.init_lr, +# min_lr=self.min_lr, +# ) + + +# class ConstantLRScheduler: +# def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs): +# self.optimizer = optimizer +# self.lr = init_lr +# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr +# self.warmup_steps = warmup_steps + +# def step(self, cur_epoch, cur_step): +# if cur_epoch == 0: +# warmup_lr_schedule( +# step=cur_step, +# optimizer=self.optimizer, +# max_step=self.warmup_steps, +# init_lr=self.warmup_start_lr, +# max_lr=self.lr, +# ) +# else: +# for param_group in self.optimizer.param_groups: +# param_group["lr"] = self.lr + + +# schedulers = { +# 'constant_lr': ConstantLRScheduler, +# 'linear_warmup_cosine_lr': LinearWarmupCosineLRScheduler, +# 'linear_warmup_step_lr': LinearWarmupStepLRScheduler +# } + + +# def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): +# """Decay the learning rate""" +# lr = (init_lr - min_lr) * 0.5 * ( +# 1.0 + math.cos(math.pi * epoch / max_epoch) +# ) + min_lr +# for param_group in optimizer.param_groups: +# param_group["lr"] = lr + + +# def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): +# """Warmup the learning rate""" +# lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) +# for param_group in optimizer.param_groups: +# param_group["lr"] = lr + + +# def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): +# """Decay the learning rate""" +# lr = max(min_lr, init_lr * (decay_rate**epoch)) +# for param_group in optimizer.param_groups: +# param_group["lr"] = lr + + +# def create_scheduler(config, optimizer): +# scheduler_cls = schedulers[config.get('scheduler', 'constant_lr')] +# max_epoch = config.epochs +# min_lr = config.min_lr +# init_lr = config.lr +# warmup_start_lr = config.get('warmup_lr', -1) +# warmup_steps = config.get('warmup_steps', 0) + +# scheduler = scheduler_cls( +# optimizer=optimizer, +# max_epoch=max_epoch, +# min_lr=min_lr, +# init_lr=init_lr, +# decay_rate=None, +# warmup_start_lr=warmup_start_lr, +# warmup_steps=warmup_steps +# ) + +# return scheduler + + + +class WarmupLinearScheduleNonZero(_LRScheduler): + """ Linear warmup and then linear decay. + Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps. + Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps. + """ + def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1): + self.warmup_steps = warmup_steps + self.t_total = t_total + self.min_lr = min_lr + super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + step = self.last_epoch + if step < self.warmup_steps: + lr_factor = float(step) / float(max(1, self.warmup_steps)) + else: + lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) + + return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs] + + +def create_scheduler(config, optimizer): + lr_scheduler = None + if config['scheduler'] == 'cosine': + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=config['num_warmup_steps'], + num_training_steps=config['num_training_steps'], + num_cycles=0.5, + min_lr_multi=config['min_lr_multi'] + ) + elif config['scheduler'] == 'linear': + lr_scheduler = WarmupLinearScheduleNonZero( + optimizer, + config['num_warmup_steps'], + config['num_training_steps'], + min_lr = config['min_lr'] + ) + return lr_scheduler + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, + num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 +): + """ + Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py + + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + min_lr_multi (`float`, *optional*, defaults to 0): + The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch)