initial commit

This commit is contained in:
Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

0
models/__init__.py Normal file
View file

1216
models/backbones/Qformer.py Executable file

File diff suppressed because it is too large Load diff

View file

247
models/backbones/base_model.py Executable file
View file

@ -0,0 +1,247 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import numpy as np
import torch
import torch.nn as nn
from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from models.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf
class BaseModel(nn.Module):
"""Base class for models."""
def __init__(self):
super().__init__()
@property
def device(self):
return list(self.parameters())[0].device
def load_checkpoint(self, url_or_filename):
"""
Load from a finetuned checkpoint.
This should expect no mismatch in the model keys and the checkpoint keys.
"""
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
if "model" in checkpoint.keys():
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
msg = self.load_state_dict(state_dict, strict=False)
logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)
return msg
@classmethod
def from_pretrained(cls, model_type):
"""
Build a pretrained model from default configuration file, specified by model_type.
Args:
- model_type (str): model type, specifying architecture and checkpoints.
Returns:
- model (nn.Module): pretrained or finetuned model, depending on the configuration.
"""
model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
model = cls.from_config(model_cfg)
return model
@classmethod
def default_config_path(cls, model_type):
assert (
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
), "Unknown model type {}".format(model_type)
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
def load_checkpoint_from_config(self, cfg, **kwargs):
"""
Load checkpoint as specified in the config file.
If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
When loading the pretrained model, each task-specific architecture may define their
own load_from_pretrained() method.
"""
load_finetuned = cfg.get("load_finetuned", True)
if load_finetuned:
finetune_path = cfg.get("finetuned", None)
assert (
finetune_path is not None
), "Found load_finetuned is True, but finetune_path is None."
self.load_checkpoint(url_or_filename=finetune_path)
else:
# load pre-trained weights
pretrain_path = cfg.get("pretrained", None)
assert "Found load_finetuned is False, but pretrain_path is None."
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
def before_evaluation(self, **kwargs):
pass
def show_n_params(self, return_str=True):
tot = 0
for p in self.parameters():
w = 1
for x in p.shape:
w *= x
tot += w
if return_str:
if tot >= 1e6:
return "{:.1f}M".format(tot / 1e6)
else:
return "{:.1f}K".format(tot / 1e3)
else:
return tot
class BaseEncoder(nn.Module):
"""
Base class for primitive encoders, such as ViT, TimeSformer, etc.
"""
def __init__(self):
super().__init__()
def forward_features(self, samples, **kwargs):
raise NotImplementedError
@property
def device(self):
return list(self.parameters())[0].device
class SharedQueueMixin:
@torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
# gather keys before updating queue
image_feats = concat_all_gather(image_feat)
text_feats = concat_all_gather(text_feat)
batch_size = image_feats.shape[0]
ptr = int(self.queue_ptr)
assert self.queue_size % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
if idxs is not None:
idxs = concat_all_gather(idxs)
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.queue_ptr[0] = ptr
class MomentumDistilationMixin:
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_m in zip(
model_pair[0].parameters(), model_pair[1].parameters()
):
param_m.data.copy_(param.data) # initialize
param_m.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for param, param_m in zip(
model_pair[0].parameters(), model_pair[1].parameters()
):
param_m.data = param_m.data * self.momentum + param.data * (
1.0 - self.momentum
)
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
@staticmethod
def forward(ctx, x):
output = [
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
torch.distributed.all_reduce(all_gradients)
return all_gradients[torch.distributed.get_rank()]
def all_gather_with_grad(tensors):
"""
Performs all_gather operation on the provided tensors.
Graph remains connected for backward grad computation.
"""
# Queue the gathered tensors
world_size = torch.distributed.get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
# tensor_all = GatherLayer.apply(tensors)
tensor_all = GatherLayer.apply(tensors)
return torch.cat(tensor_all, dim=0)
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
# if use distributed training
if not is_dist_avail_and_initialized():
return tensor
tensors_gather = [
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
)
return torch.index_select(x, dim, order_index.to(x.device))

View file

View file

@ -0,0 +1,107 @@
import logging
import torch
from models.utils import (interpolate_pos_relative_bias_beit,
load_temp_embed_with_mismatch)
logger = logging.getLogger(__name__)
def interpolate_pos_embed_beit(state_dict, new_model):
"""interpolate the positional embeddings.
The spatial pe is relative and temporal pe is absolute.
additional temporal pe is padded with 0.
Args:
state_dict (dict): The state_dict.
new_model (nn.Module): The created model.
Returns: dict. The state_dict with updated positional embeddings.
"""
state_dict = interpolate_pos_relative_bias_beit(
state_dict_old=state_dict,
state_dict_new=new_model.state_dict(),
patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape,
)
# absolute temporal pos bias
temporal_pe_key = "beit.embeddings.temporal_position_embeddings"
if temporal_pe_key in state_dict:
logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
temp_embed_old=state_dict[temporal_pe_key],
temp_embed_new=new_model.state_dict()[temporal_pe_key],
)
return state_dict
def extract_beit_from_vindlu(vindlu_state_dict):
beit_state_dict = {}
beit_param_names = [k for k in vindlu_state_dict if k.startswith('vision_encoder.') and 'temp_model' not in k]
for param_name in beit_param_names:
new_name = param_name.replace('vision_encoder.', '')
beit_state_dict[new_name] = vindlu_state_dict[param_name]
return beit_state_dict
def build_beit(model_config, image_res, checkpoint=False):
"""build beit with configuration.
Args:
config (dict): The configs for beit.
image_res (int): The image resolution.
checkpoint (bool): Whether to enable gradient checkpointing.
Returns: nn.Module
"""
from .st_beit import BeitConfig as config_cls
from .st_beit import BeitModel as model_cls
vindlu_state_dict = torch.load(model_config['vindlu_path'])['model']
state_dict = extract_beit_from_vindlu(vindlu_state_dict)
model_config = model_config['beit_config_json']
logger.info(
f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}."
)
# BEiT uses average pooled tokens instead of [CLS] used by other models
aux_kwargs = {"add_pooling_layer": True}
# tmp_model = model_cls.from_pretrained(model_config['beit_pretrained'], **aux_kwargs)
# tmp_model = model_cls.from_pretrained(model_config['pretrained'], **aux_kwargs)
# state_dict = tmp_model.state_dict()
# del tmp_model
logger.info(f"Init new model with new image size {image_res}, and load weights.")
# other_cfg = model_config.temporal_modeling
other_cfg = {}
vit_config = config_cls.from_pretrained(
model_config['pretrained'], image_size=image_res, **other_cfg
)
# vit_config.update(model_config)
model = model_cls(config=vit_config, **aux_kwargs)
if checkpoint:
model.gradient_checkpointing_enable()
# interpolate relative pos bias
state_dict = interpolate_pos_relative_bias_beit(
state_dict_old=state_dict,
state_dict_new=model.state_dict(),
patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
)
# del prompt_bias_table
for k in list(state_dict.keys()):
if "prompt_bias_table" in k:
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
logger.info(msg)
return model

File diff suppressed because it is too large Load diff

View file

View file

@ -0,0 +1,71 @@
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
"""build text encoder.
Args:
model_config (dict): model config.
pretrain (bool): Whether to do pretrain or finetuning.
checkpoint (bool): whether to do gradient_checkpointing.
Returns: TODO
"""
bert_size = model_config['expert_size']
bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
# bert_config.encoder_width = model_config.vision_encoder.d_model
bert_config.gradient_checkpointing = checkpoint
bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]
if expert_type=='modality':
if modality_type == 'vis':
bert_config.cross_attention_freq = 2
else:
bert_config.cross_attention_freq = -1
else:
bert_config.cross_attention_freq = 1
if pretrain:
text_encoder, loading_info = BertForMaskedLM.from_pretrained(
f'bert-{bert_size}-uncased',
config=bert_config,
output_loading_info=True,
)
else:
text_encoder, loading_info = BertModel.from_pretrained(
f'bert-{bert_size}-uncased',
config=bert_config,
add_pooling_layer=True,
output_loading_info=True,
)
return text_encoder
def build_bert_decoder(model_config, checkpoint):
"""build text decoder the same as the multimodal encoder.
Args:
model_config (dict): model config.
pretrain (bool): Whether to do pretrain or finetuning.
checkpoint (bool): whether to do gradient_checkpointing.
Returns: TODO
"""
bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
bert_config.encoder_width = model_config.vision_encoder.d_model
bert_config.gradient_checkpointing = checkpoint
bert_config.fusion_layer = 0
bert_config.num_hidden_layers = (
bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
)
text_decoder, loading_info = BertLMHeadModel.from_pretrained(
model_config.text_encoder.pretrained,
config=bert_config,
output_loading_info=True,
)
return text_decoder

View file

@ -0,0 +1,546 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""
import collections
import os
import unicodedata
from typing import List, Optional, Tuple
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-multilingual-uncased": 512,
"bert-base-multilingual-cased": 512,
"bert-base-chinese": 512,
"bert-base-german-cased": 512,
"bert-large-uncased-whole-word-masking": 512,
"bert-large-cased-whole-word-masking": 512,
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"TurkuNLP/bert-base-finnish-cased-v1": 512,
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
"wietsedv/bert-base-dutch-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-multilingual-uncased": {"do_lower_case": True},
"bert-base-multilingual-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
"bert-base-german-cased": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer(PreTrainedTokenizer):
r"""
Construct a BERT tokenizer. Based on WordPiece.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to do basic tokenization before WordPiece.
never_split (:obj:`Iterable`, `optional`):
Collection of tokens which will never be split during tokenization. Only has an effect when
:obj:`do_basic_tokenize=True`
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this `issue
<https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
vocab_file)
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(
vocab=self.vocab, unk_token=self.unk_token)
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: ``[CLS] X ``
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix +
"-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
"Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(
vocab_file)
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
class BasicTokenizer(object):
"""
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
Args:
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
never_split (:obj:`Iterable`, `optional`):
Collection of tokens which will never be split during tokenization. Only has an effect when
:obj:`do_basic_tokenize=True`
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this `issue
<https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
"""
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents
def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
WordPieceTokenizer.
Args:
**never_split**: (`optional`) list of str
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
:func:`PreTrainedTokenizer.tokenize`) List of token not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(
set(never_split)) if never_split else self.never_split
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens

File diff suppressed because it is too large Load diff

268
models/backbones/blip2.py Executable file
View file

@ -0,0 +1,268 @@
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
# import .backbones.common.dist_utils as dist_utils
# from minigpt4.common.dist_utils import download_cached_file
# from minigpt4.common.utils import is_url
# from minigpt4.common.logger import MetricLogger
from models.backbones.base_model import BaseModel
from models.backbones.Qformer import BertConfig, BertLMHeadModel
from models.backbones.eva_vit import create_eva_vit_g
from transformers import BertTokenizer
class Blip2Base(BaseModel):
@classmethod
def init_tokenizer(cls):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
return tokenizer
def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
@classmethod
def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
return Qformer, query_tokens
@classmethod
def init_vision_encoder(
cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
):
assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
visual_encoder = create_eva_vit_g(
img_size, drop_path_rate, use_grad_checkpoint, precision
)
ln_vision = LayerNorm(visual_encoder.num_features)
return visual_encoder, ln_vision
def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
msg = self.load_state_dict(state_dict, strict=False)
# logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)
return msg
def get_optimizer_params(self, weight_decay, lr_scale=1):
vit_num_layers = self.visual_encoder.get_num_layer()
lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
parameter_group_names = {}
parameter_group_vars = {}
for name, param in self.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias"):
group_name = "no_decay"
this_weight_decay = 0.
else:
group_name = "decay"
this_weight_decay = weight_decay
if 'visual_encoder' in name:
layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.',''))
group_name = "vit_layer_%d_%s" % (layer_id, group_name)
else:
layer_id = None
if group_name not in parameter_group_names:
if layer_id is not None:
scale = lr_scales[layer_id]
else:
scale = 1
parameter_group_names[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
# import json
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
optim_params = list(parameter_group_vars.values())
return optim_params
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def compute_sim_matrix(model, data_loader, **kwargs):
k_test = kwargs.pop("k_test")
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
logging.info("Computing features for evaluation...")
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = model.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=35,
return_tensors="pt",
).to(model.device)
text_feat = model.forward_text(text_input)
text_embed = F.normalize(model.text_proj(text_feat))
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds, dim=0)
text_ids = torch.cat(text_ids, dim=0)
text_atts = torch.cat(text_atts, dim=0)
vit_feats = []
image_embeds = []
for samples in data_loader:
image = samples["image"]
image = image.to(model.device)
image_feat, vit_feat = model.forward_image(image)
image_embed = model.vision_proj(image_feat)
image_embed = F.normalize(image_embed, dim=-1)
vit_feats.append(vit_feat.cpu())
image_embeds.append(image_embed)
vit_feats = torch.cat(vit_feats, dim=0)
image_embeds = torch.cat(image_embeds, dim=0)
sims_matrix = []
for image_embed in image_embeds:
sim_q2t = image_embed @ text_embeds.t()
sim_i2t, _ = sim_q2t.max(0)
sims_matrix.append(sim_i2t)
sims_matrix = torch.stack(sims_matrix, dim=0)
score_matrix_i2t = torch.full(
(len(data_loader.dataset.image), len(texts)), -100.0
).to(model.device)
num_tasks = dist_utils.get_world_size()
rank = dist_utils.get_rank()
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
score = model.compute_itm(
image_inputs=image_inputs,
text_ids=text_ids[topk_idx],
text_atts=text_atts[topk_idx],
).float()
score_matrix_i2t[start + i, topk_idx] = score + topk_sim
sims_matrix = sims_matrix.t()
score_matrix_t2i = torch.full(
(len(texts), len(data_loader.dataset.image)), -100.0
).to(model.device)
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
score = model.compute_itm(
image_inputs=image_inputs,
text_ids=text_ids[start + i].repeat(k_test, 1),
text_atts=text_atts[start + i].repeat(k_test, 1),
).float()
score_matrix_t2i[start + i, topk_idx] = score + topk_sim
if dist_utils.is_dist_avail_and_initialized():
dist.barrier()
torch.distributed.all_reduce(
score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.all_reduce(
score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logging.info("Evaluation time {}".format(total_time_str))
return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

110
models/backbones/blip2_outputs.py Executable file
View file

@ -0,0 +1,110 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import (
ModelOutput,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
@dataclass
class BlipSimilarity(ModelOutput):
sim_i2t: torch.FloatTensor = None
sim_t2i: torch.FloatTensor = None
sim_i2t_m: Optional[torch.FloatTensor] = None
sim_t2i_m: Optional[torch.FloatTensor] = None
sim_i2t_targets: Optional[torch.FloatTensor] = None
sim_t2i_targets: Optional[torch.FloatTensor] = None
@dataclass
class BlipIntermediateOutput(ModelOutput):
"""
Data class for intermediate outputs of BLIP models.
image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
decoder_labels (torch.LongTensor): labels for the captioning loss.
itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
"""
# uni-modal features
image_embeds: torch.FloatTensor = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds_m: Optional[torch.FloatTensor] = None
text_embeds_m: Optional[torch.FloatTensor] = None
# intermediate outputs of multimodal encoder
encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
itm_logits: Optional[torch.FloatTensor] = None
itm_labels: Optional[torch.LongTensor] = None
# intermediate outputs of multimodal decoder
decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
decoder_labels: Optional[torch.LongTensor] = None
@dataclass
class BlipOutput(ModelOutput):
# some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
sims: Optional[BlipSimilarity] = None
intermediate_output: BlipIntermediateOutput = None
loss: Optional[torch.FloatTensor] = None
loss_itc: Optional[torch.FloatTensor] = None
loss_itm: Optional[torch.FloatTensor] = None
loss_lm: Optional[torch.FloatTensor] = None
@dataclass
class BlipOutputFeatures(ModelOutput):
"""
Data class of features from BlipFeatureExtractor.
Args:
image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
The first embedding or feature is for the [CLS] token.
Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
"""
image_embeds: Optional[torch.FloatTensor] = None
image_embeds_proj: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
text_embeds_proj: Optional[torch.FloatTensor] = None
multimodal_embeds: Optional[torch.FloatTensor] = None

View file

@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionEncoder(nn.Module):
def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_encoder_name = encoder_name
# self.select_layer = args.mm_vision_select_layer
# self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.select_layer = -1
self.select_feature = "patch"
if not delay_load:
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
self.vision_encoder.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, :]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
# print("image feature shape", image_features.shape)
# print(type(image_forward_outs))
# print(type(image_forward_outs.shape))
# image_features = image_forward_outs.to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_encoder.dtype
@property
def device(self):
return self.vision_encoder.device
@property
def config(self):
if self.is_loaded:
return self.vision_encoder.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2

View file

@ -0,0 +1,141 @@
import glog as logger
import re
import json
from peft import LoraConfig, get_peft_model
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart import BartConfig, BartForConditionalGeneration, BartEncoder, BartForCausalLM
def build_encoder_decoder(model_config):
"""build (encoder-) decoder model for answer generation.
Args:
model_config (dict): model config.
Returns: TODO
"""
logger.info('[INFO] Loading Encoder Decoder [Type = {}]'.format(model_config['enc_dec_name']))
if model_config['enc_dec_family'] == 'flan_t5':
config_cls = T5Config
model_cls = T5ForConditionalGeneration
elif model_config['enc_dec_family'] == 'bart':
config_cls = BartConfig
if model_config['use_decoder_only']:
model_cls = BartForCausalLM
else:
model_cls = BartForConditionalGeneration
else:
raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
enc_dec_config = config_cls.from_pretrained(model_config['enc_dec_name'])
model_config['enc_dec_dim'] = enc_dec_config.d_model
# enc_dec_config.encoder_layers = enc_dec_config.encoder_layers - model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
enc_dec = model_cls.from_pretrained(
model_config['enc_dec_name'],
config=enc_dec_config
)
# first_k = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
# enc_dec.model.encoder.remove_first_k_layers(first_k)
# get the last encoder layers
# enc_dec.
if model_config['use_lora_enc_dec']:
# load the lora config
with open(model_config['lora_config'], 'r') as f:
lora_config = json.load(f)
# get the linear layer to perform LoRA on
model_modules = str(enc_dec.modules)
pattern = r'\((\w+)\): Linear'
linear_layer_names = re.findall(pattern, model_modules)
names = []
# Print the names of the Linear layers
for name in linear_layer_names:
names.append(name)
target_modules = list(set(names))
lora_config['target_modules'] = target_modules
lora_config = LoraConfig(**lora_config)
enc_dec = get_peft_model(enc_dec, lora_config)
return enc_dec
def build_encoder(model_config, expert_type, modality=None):
"""build (encoder-) decoder model for answer generation.
Args:
model_config (dict): model config.
Returns: TODO
"""
log_txt = '[INFO] Loading {} Expert'.format(expert_type)
if modality is not None:
log_txt += ' [Modality = {}]'.format(modality)
log_txt += ' [Type = {}]'.format(model_config['enc_dec_name'])
logger.info(log_txt)
if model_config['enc_dec_family'] == 'flan_t5':
config_cls = T5Config
model_cls = T5ForConditionalGeneration
elif model_config['enc_dec_family'] == 'bart':
config_cls = BartConfig
model_cls = BartEncoder
else:
raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
config = config_cls.from_pretrained(model_config['enc_dec_name'])
config.modality_expert_layers = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
config.grounding_expert_layers = model_config['num_layers_grounding_expert_{}'.format(model_config['enc_dec_family'])]
model_config['enc_dec_dim'] = config.d_model
expert = model_cls.from_pretrained(
model_config['enc_dec_name'],
config=config,
expert_type=expert_type,
modality=modality
)
if model_config['use_lora_expert']:
# load the lora config
with open(model_config['lora_config'], 'r') as f:
lora_config = json.load(f)
# get the linear layer to perform LoRA on
model_modules = str(expert.modules)
pattern = r'\((\w+)\): Linear'
linear_layer_names = re.findall(pattern, model_modules)
names = []
# Print the names of the Linear layers
for name in linear_layer_names:
names.append(name)
target_modules = list(set(names))
lora_config['target_modules'] = target_modules
lora_config = LoraConfig(**lora_config)
expert = get_peft_model(expert, lora_config)
# expert = model_cls(
# config=config,
# expert_type=expert_type,
# modality=modality
# )
return expert

View file

@ -0,0 +1,65 @@
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart_original import BartConfig, BartForConditionalGeneration, BartEncoder
import glog as logger
def build_encoder_decoder(model_config):
"""build (encoder-) decoder model for answer generation.
Args:
model_config (dict): model config.
Returns: TODO
"""
logger.info('[INFO] Loading Encoder Decoder: {}'.format(model_config['enc_dec_name']))
if model_config['enc_dec_family'] == 'flan_t5':
config_cls = T5Config
model_cls = T5ForConditionalGeneration
elif model_config['enc_dec_family'] == 'bart':
config_cls = BartConfig
model_cls = BartForConditionalGeneration
else:
raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
config = config_cls.from_pretrained(model_config['enc_dec_name'])
model_config['enc_dec_dim'] = config.d_model
enc_dec = model_cls.from_pretrained(
model_config['enc_dec_name'],
config=config
)
return enc_dec
def build_encoder(model_config):
"""build (encoder-) decoder model for answer generation.
Args:
model_config (dict): model config.
Returns: TODO
"""
logger.info('[INFO] Loading Expert as Encoder of {}'.format(model_config['enc_dec_name']))
if model_config['enc_dec_family'] == 'flan_t5':
config_cls = T5Config
model_cls = T5ForConditionalGeneration
elif model_config['enc_dec_family'] == 'bart':
config_cls = BartConfig
model_cls = BartEncoder
else:
raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
config = config_cls.from_pretrained(model_config['enc_dec_name'])
model_config['enc_dec_dim'] = config.d_model
config.encoder_layers = model_config['num_layers_modality_expert']
expert = model_cls.from_pretrained(
model_config['enc_dec_name'],
config=config
)
return expert

View file

@ -0,0 +1,19 @@
from typing import Optional, Tuple
import torch
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

455
models/backbones/eva_vit.py Executable file
View file

@ -0,0 +1,455 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from models.common.dist_utils import download_cached_file
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
self.use_checkpoint = use_checkpoint
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
# if isinstance(self.head, nn.Linear):
# trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
# if isinstance(self.head, nn.Linear):
# self.head.weight.data.mul_(init_scale)
# self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias)
return x
# x = self.norm(x)
# if self.fc_norm is not None:
# t = x[:, 1:, :]
# return self.fc_norm(t.mean(1))
# else:
# return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
def get_intermediate_layers(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
features = []
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
x = blk(x, rel_pos_bias)
features.append(x)
return features
def get_num_layer(self, var_name=""):
if var_name in ("cls_token", "mask_token", "pos_embed"):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("rel_pos_bias"):
return len(self.blocks) - 1
elif var_name.startswith("blocks"):
layer_id = int(var_name.split('.')[1])
return layer_id + 1
else:
return len(self.blocks)
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
# if isinstance(l, (nn.MultiheadAttention, Attention)):
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
# tensor = getattr(l, attr)
# if tensor is not None:
# tensor.data = tensor.data.half()
model.apply(_convert_weights_to_fp16)
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
model = VisionTransformer(
img_size=img_size,
patch_size=14,
use_mean_pooling=False,
embed_dim=1408,
depth=39,
# depth = 37,
num_heads=1408//88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
cached_file = download_cached_file(
url, check_hash=False, progress=True
)
state_dict = torch.load(cached_file, map_location="cpu")
interpolate_pos_embed(model,state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
# print(incompatible_keys)
if precision == "fp16":
# model.to("cuda")
convert_weights_to_fp16(model)
return model

View file

@ -0,0 +1,895 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from transformers import LlamaTokenizer
from transformers import BitsAndBytesConfig
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
import time
import numpy as np
from minigpt4.models import policies
@registry.register_model("mini_gpt4_llama_v2")
class MiniGPT4_llama_v2(Blip2Base):
"""
BLIP2 GPT-LLAMA model.
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/models/minigpt4.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
low_resource=False, # use 8 bit and put vit in cpu
end_sym='\n',
lora_r = 8,
lora_target_modules = ["q_proj","v_proj"],
lora_alpha=16,
# lora_r = 16,
# lora_target_modules = ["q_proj","v_proj","v_proj"],
lora_dropout= 0.05,
ckpt_path = "",
system_prompt= False,
chat_template=False,
token_pooling=True,
use_grad_checkpoint_llm=False,
max_context_len=3800,
remove_template = False,
):
super().__init__()
if "Mistral" in llama_model:
from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
print("Mistral model")
self.model_type = "Mistral"
else:
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
print("Llama model")
self.model_type = "Llama"
self.tokenizer = self.init_tokenizer()
self.low_resource = low_resource
self.token_pooling = token_pooling
self.remove_template = remove_template
print("token pooling", self.token_pooling)
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
self.max_context_len = max_context_len
self.chat_template = chat_template
# print('Loading VIT')
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
# )
if freeze_vit:
# vit_precision="fp32"
print("vit precision", vit_precision)
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
)
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
for name, param in self.ln_vision.named_parameters():
param.requires_grad = False
self.ln_vision = self.ln_vision.eval()
self.ln_vision.train = disabled_train
logging.info("freeze vision encoder")
print("freeze the vision encoder")
else:
vit_precision="fp32"
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
)
print("unfreeze the vision encoder")
print('Loading VIT Done')
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
# assert False
print('Loading LLAMA')
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model,use_fast=False) #
self.llama_tokenizer.pad_token = "$$"
self.system_prompt = system_prompt
print("self.low_resource",self.low_resource)
if self.low_resource:
self.llama_model = llm_model.from_pretrained(
llama_model,
torch_dtype=torch.float16,
# torch_dtype = torch.bfloat16,
load_in_8bit=True,
# device_map = "balanced"
# device_map="auto",
device_map={'':torch.cuda.current_device()},
# device_map={'':0}
)
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16,
# )
# self.llama_model = llm_model.from_pretrained(
# llama_model,
# torch_dtype=torch.bfloat16,
# device_map={'':torch.cuda.current_device()},
# quantization_config=bnb_config,
# )
else:
self.llama_model = llm_model.from_pretrained(
llama_model,
torch_dtype=torch.float16,
)
# self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
self.llama_model = prepare_model_for_int8_training(self.llama_model)
loraconfig = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
self.llama_model = get_peft_model(self.llama_model, loraconfig)
# if ckpt_path:
# print('load the llm under lora')
# ckpt = torch.load(ckpt_path)
# set_peft_model_state_dict(self.llama_model,ckpt)
self.llama_model.print_trainable_parameters()
if self.use_grad_checkpoint_llm:
self.llama_model.gradient_checkpointing_enable()
# if not self.low_resource:
# for name, param in self.llama_model.named_parameters():
# if "embed_token" in name:
# param.data = param.data.float()
# param.requires_grad = True
print('Loading LLAMA Done')
if self.token_pooling:
self.llama_proj = nn.Linear(
1408*4, self.llama_model.config.hidden_size
)
else:
self.llama_proj = nn.Linear(
1408, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
if prompt_path:
with open(prompt_path, 'r') as f:
raw_prompts = f.read().splitlines()
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
print('Load {} training prompts'.format(len(self.prompt_list)))
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
else:
self.prompt_list = []
def encode_img(self, image):
device = image.device
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
bs, pn, hs = image_embeds.shape
if self.token_pooling: # concat the each 4 tokens into one token (200,64,5632)
image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def get_context_emb(self, prompt, img_list):
img_device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
# # truncate the length of tokens to the max context window
# mixed_embs_without_instruction = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair]
# mixed_embs_without_instruction=torch.cat(mixed_embs_without_instruction, dim=1)
# # check if the number of token in the second dimention is more than the max context window then truncate it
# context_window=self.max_context_len-seg_embs[-1].shape[1]
# if mixed_embs_without_instruction.shape[1] > context_window :
# mixed_embs_without_instruction = mixed_embs_without_instruction[:, 0:context_window]
# mixed_embs=torch.cat([mixed_embs_without_instruction,seg_embs[-1]], dim=1)
# print("mixed_embs",mixed_embs.shape)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
if prompts is None or len(prompts) == 0:
# prompts is not provided, just return the original image embedding
return img_embeds, atts_img
elif img_embeds is None:
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
self.llama_tokenizer.padding_side = "right"
prompt_tokens = self.llama_tokenizer(
prompts,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(self.device)
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
atts_prompt = prompt_tokens.attention_mask
return prompt_embeds, atts_prompt
else:
# return the multi-modal embedding in right padding
emb_lists = []
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
pn = each_img_embed.shape[-2]
if lengths is not None:
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
each_img_embed = each_img_embed[:lengths[idx] * pn]
p_segs = each_prompt.split('<ImageHere>')
interleave_emb = []
for idx, seg in enumerate(p_segs[:-1]):
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
# print("p_embed device",p_tokens.input_ids.device)
# print("p_tokens",img_embeds.device)
# print("emb layer", list(self.llama_model.base_model.model.model.embed_tokens.parameters())[0].device)
p_embed = self.embed_tokens(p_tokens.input_ids)
# print("model device",self.llama_model.get_device())
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
wrapped_emb = torch.cat(interleave_emb, dim=1)
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
emb_lists.append(wrapped_emb)
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
# max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
max_length = self.max_context_len
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
for i, emb in enumerate(emb_lists):
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
wrapped_embs[i, :length] = emb[:, :length]
wrapped_atts[i, :length] = 1
return wrapped_embs, wrapped_atts
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
"""
Concatenate the batched input embedding and batched output embedding together.
Both the input and the output embedding should be right padded.
"""
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)):
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
)
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
)
# print('===================================')
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
# print('check out emb: ', output_embs[i][:2])
# print('check out pad emb: ', output_embs[i][-2:])
# print('+++++++++++++++++++++++++++++++++++')
#
# print('check attn before: ', input_atts[i][:this_input_ones])
# print('check attn after: ', input_atts[i][this_input_ones:])
# print('check attn gt before: ', output_atts[i][:3])
# print('check attn gt after: ', output_atts[i][-3:])
cat_embs = torch.stack(cat_embs)
cat_atts = torch.stack(cat_atts)
return cat_embs, cat_atts, input_lens
def get_conv_emb(self, conv_q, conv_a, conv_img):
"""concatenate conversation and make sure the model is only trained to regress the answer"""
regress_embs_list = []
targets_list = []
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
assigned_imgs = conv_img[batch_idx]
questions = [self.prompt_wrap(
img_embeds=img,
atts_img=None,
prompts=[q],
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
q_embs = [emb for emb, _ in questions]
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
cur_emb = []
cur_target = []
for i in range(len(questions)):
cur_emb.append(q_embs[i])
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
cur_emb.append(self.embed_tokens(answers[i].input_ids))
cur_target.append(answers[i].input_ids)
cur_emb = torch.cat(cur_emb, dim=1)
cur_target = torch.cat(cur_target, dim=1)
regress_embs_list.append(cur_emb)
targets_list.append(cur_target)
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
for batch_idx in range(batch_size):
cur_len = regress_embs_list[batch_idx].shape[1]
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
regress_attn[batch_idx, :cur_len] = 1
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
return regress_embeds, regress_attn, targets
def preparing_embedding(self, samples):
def remove_special_tokens(data):
# if "instruction_input" in data:
data = [instruct.replace(" [caption]","") for instruct in data]
data = [instruct.replace(" [vqa]","") for instruct in data]
data = [instruct.replace(" [grounding]","") for instruct in data]
data = [instruct.replace(" [identify]","") for instruct in data]
data = [instruct.replace(" [refer]","") for instruct in data]
return data
### prepare input tokens
if 'image' in samples:
img_embeds, img_atts = self.encode_img(samples["image"])
# print("img_embeds shape",img_embeds.shape)
else:
img_embeds = img_atts = None
if 'conv_q' in samples:
# handeling conversation datasets
conv_q, conv_a = samples['conv_q'], samples['conv_a']
connect_sym = samples['connect_sym'][0]
conv_q = [q.split(connect_sym)for q in conv_q]
conv_a = [a.split(connect_sym) for a in conv_a]
conv_img = assign_imgs(conv_q, img_embeds)
if self.chat_template:
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
else:
instruction = samples["instruction_input"] if "instruction_input" in samples else None
# print("instruction before", instruction)
if self.remove_template:
instruction = remove_special_tokens(instruction)
# print("instruction after", instruction)
if self.chat_template:
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
if 'length' in samples:
# the input is a image train (like videos)
bsz, pn, hs = img_embeds.shape
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096)
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
else:
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
### prepare target tokens
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["answer"]]
regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(self.device)
regress_token_ids = regress_tokens.input_ids
regress_atts = regress_tokens.attention_mask
part_targets = regress_token_ids.masked_fill(
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
)
regress_embeds = self.embed_tokens(regress_token_ids)
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
def forward(self, samples, reduction="mean"):
# prepare the embedding to condition and the embedding to regress
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
self.preparing_embedding(samples)
# concat the embedding to condition and the embedding to regress
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
print("inputs_embeds shape",inputs_embeds.shape)
print("cond_embeds shape",cond_embeds.shape)
print("regress_embeds shape",regress_embeds.shape)
# get bos token embedding
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos)
bos_atts = attention_mask[:, :1]
# add bos token at the begining
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
# print length of instruction_input and answer words
# for i in range (len(samples["instruction_input"])):
# print("instruction_input length",len(samples["instruction_input"][i].split(" ")))
# print("answer length",len(samples["answer"][i].split(" ")))
# ensemble the final targets
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(self.device).fill_(-100)
for i, target in enumerate(part_targets):
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
print("targets shape",targets.shape)
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
reduction=reduction
)
loss = outputs.loss
return {"loss": loss}
@torch.no_grad()
def generate(
self,
images,
texts,
use_nucleus_sampling=False,
num_beams=1,
max_new_tokens=20,
min_length=1,
top_p=0.9,
repetition_penalty=1.5,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
lengths=None,
return_video_temporal_features=False,
img_embeds=None,
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
if img_embeds is None:
img_embeds, atts_img = self.encode_img(images.to(self.device))
else:
# Use images features from the input(4,45,64,5632)
img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
img_embeds= img_embeds.to(self.device)
img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)
print("img_embeds shape",img_embeds.shape)
if lengths is not None:
image_lists = []
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
for idx, img_embed in enumerate(img_embeds):
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
else:
image_lists = [[image_emb[None]] for image_emb in img_embeds]
assert len(texts) == len(image_lists)
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
# print("inputs_embeds shape",embs.shape)
# print("attention_mask shape",attn_mask.shape)
# check if the input embedding tokens are in the range of the model cotext window (4096) and if it is not, then truncate it to the max context window
if self.model_type == "Llama":
context_window = 3700
else:
context_window = 7500
if embs.shape[1] > context_window:
embs = embs[:, -context_window:]
attn_mask = attn_mask[:, -context_window:]
print("inputs_embeds shape",embs.shape)
print("attention_mask shape",attn_mask.shape)
with self.maybe_autocast():
if return_video_temporal_features:
last_hidden_state = self.llama_model(
inputs_embeds=embs,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[-1]
video_temporal_features = last_hidden_state.mean(dim=1)
# normalize the temporal features using L2 norm
# video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True)
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
temperature=temperature,
repetition_penalty=repetition_penalty,
# stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
if return_video_temporal_features:
return answers, video_temporal_features
else:
return answers
@torch.no_grad()
def generate_text_only(
self,
images,
seg_tokens,
use_nucleus_sampling=False,
num_beams=1,
max_new_tokens=20,
min_length=1,
top_p=0.9,
repetition_penalty=1.5,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
lengths=None,
return_video_temporal_features=False,
img_embeds=None,
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# seg_tokens=[]
# for i, text in enumerate(texts):
# seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids)
batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]
# seg_embs = torch.cat(seg_embs, dim=1)
# print("seg_embs shape",seg_embs.shape)
# batch_embs=[seg_embs]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
print("inputs_embeds shape",embs.shape)
print("attention_mask shape",attn_mask.shape)
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
temperature=temperature,
repetition_penalty=repetition_penalty,
# stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
return answers
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
all_losses = []
for answer in answers:
choice_samples = {
'image': images,
'instruction_input': texts,
'answer': answer
}
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
all_losses.append(loss)
torch.cuda.empty_cache()
all_losses = torch.cat(all_losses, dim=-1)
if num_cand is not None:
for i in range(all_losses.shape[0]):
all_losses[i, num_cand[i]:] = 9999
output_class_ranks = torch.argsort(all_losses, dim=-1)
return output_class_ranks.tolist()
def predict_answers(
self,
samples,
num_beams=5,
inference_method="generate",
max_len=10,
min_len=1,
num_ans_candidates=128,
answer_list=None,
prompt="",
length_penalty=0,
**kwargs
):
'''
function for open-ended VQA
'''
images = samples["image"].cuda()
texts = samples["instruction_input"]
output_text = self.generate(
images=images,
texts=texts,
num_beams=num_beams,
max_new_tokens=max_len,
min_length=min_len,
length_penalty=length_penalty
)
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
output_text = self._lemmatize(output_text)
return output_text
def predict_class(
self,
samples,
num_beams=5,
inference_method="generate",
max_len=10,
min_len=1,
num_ans_candidates=5,
answer_list=None,
prompt="",
length_penalty=0,
**kwargs
):
'''
function for multi-choice VQA
'''
image = samples["image"].cuda()
instruction = samples['instruction_input']
answers = samples["choices"]
num_cand = samples["num_choices"]
ranks = self.multi_select(image, instruction, answers, num_cand)
pred_ans = []
for i, rank in enumerate(ranks):
pred = answers[rank[0]][i]
pred_ans.append(pred)
return pred_ans
def embed_tokens(self, token_ids):
try:
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
except AttributeError:
embeds = self.llama_model.model.embed_tokens(token_ids)
return embeds
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 300)
end_sym = cfg.get("end_sym", '\n')
lora_r = cfg.get("lora_r",64)
lora_alpha = cfg.get("lora_alpha",16)
chat_template = cfg.get("chat_template",False)
system_prompt = cfg.get("system_prompt", False)
token_pooling = cfg.get("token_pooling",True)
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
max_context_len = cfg.get("max_context_len", 3800)
remove_template = cfg.get("remove_template", False)
model = cls(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
low_resource=low_resource,
end_sym=end_sym,
lora_r = lora_r,
lora_alpha = lora_alpha,
chat_template = chat_template,
system_prompt = system_prompt,
token_pooling = token_pooling,
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
max_context_len=max_context_len,
remove_template = remove_template
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model
def assign_imgs(batched_instruct_list, batched_img_embeds):
'''this function is used when the data is interleaved.
the interlevaed data is separated, and this function assign
corresponding image embeddings to each segment'''
if len(batched_img_embeds.shape) == 3:
batched_img_embeds = batched_img_embeds[:, None]
batched_assigned = []
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
img_idx = 0
assigned_img = []
n_assigned = []
for instruct in instruct_list:
n_img = instruct.count('<ImageHere>')
if n_img > 0: # this instruction include images.
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
img_idx += n_img
n_assigned.append(n_img)
else: # this instruction doesn't include images
assigned_img.append(None)
n_assigned.append(None)
batched_assigned.append(assigned_img)
return batched_assigned

709
models/backbones/mini_gpt4v.py Executable file
View file

@ -0,0 +1,709 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training
)
import time
import numpy as np
from minigpt4.models import policies
@registry.register_model("mini_gpt4v")
class MiniGPT4v(Blip2Base):
"""
BLIP2 GPT-LLAMA model.
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/models/minigpt4.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
low_resource=False, # use 8 bit and put vit in cpu
end_sym='\n',
lora_r = 8,
lora_target_modules = ["q_proj","v_proj"],
lora_alpha=16,
# lora_r = 16,
# lora_target_modules = ["q_proj","v_proj","v_proj"],
lora_dropout= 0.05,
ckpt_path = "",
system_prompt= False,
chat_template=False,
token_pooling=True,
use_grad_checkpoint_llm=False,
max_context_len=3800,
remove_template = False,
):
super().__init__()
self.tokenizer = self.init_tokenizer()
self.low_resource = low_resource
self.token_pooling = token_pooling
self.remove_template = remove_template
print("token pooling", self.token_pooling)
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
self.max_context_len = max_context_len
self.chat_template = chat_template
# print('Loading VIT')
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
# )
print("vit precision", vit_precision)
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
)
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
for name, param in self.ln_vision.named_parameters():
param.requires_grad = False
self.ln_vision = self.ln_vision.eval()
self.ln_vision.train = disabled_train
logging.info("freeze vision encoder")
print("freeze the vision encoder")
print('Loading VIT Done')
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
# assert False
print('Loading LLAMA')
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
if 'CodeLlama' in llama_model:
self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
self.llama_tokenizer.pad_token = "$$"
else:
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
self.llama_tokenizer.pad_token = "$$"
self.system_prompt = system_prompt
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
quantization_config=bnb_config,
device_map={"": 0}
)
# self.llama_model.gradient_checkpointing_enable()
self.llama_model = prepare_model_for_kbit_training(self.llama_model)
# self.llama_model.print_trainable_parameters()
print('Loading LLAMA Done')
self.merge_n = 3
self.llama_proj = nn.Linear(
1408 * self.merge_n**2, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
if prompt_path:
with open(prompt_path, 'r') as f:
raw_prompts = f.read().splitlines()
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
print('Load {} training prompts'.format(len(self.prompt_list)))
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
else:
self.prompt_list = []
def encode_img(self, image):
device = image.device
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:])
bs, ch, w, h = image.shape
assert w % 224 == 0
bw = w // 224
assert h % 224 == 0
bh = h // 224
image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224
image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)
with self.maybe_autocast():
image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs
image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def get_context_emb(self, prompt, img_list):
img_device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
if prompts is None or len(prompts) == 0:
# prompts is not provided, just return the original image embedding
return img_embeds, atts_img
elif img_embeds is None:
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
self.llama_tokenizer.padding_side = "right"
prompt_tokens = self.llama_tokenizer(
prompts,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(self.device)
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
atts_prompt = prompt_tokens.attention_mask
return prompt_embeds, atts_prompt
else:
# return the multi-modal embedding in right padding
emb_lists = []
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
pn = each_img_embed.shape[-2]
if lengths is not None:
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
each_img_embed = each_img_embed[:lengths[idx] * pn]
p_segs = each_prompt.split('<ImageHere>')
interleave_emb = []
for idx, seg in enumerate(p_segs[:-1]):
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
wrapped_emb = torch.cat(interleave_emb, dim=1)
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
emb_lists.append(wrapped_emb)
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
for i, emb in enumerate(emb_lists):
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
wrapped_embs[i, :length] = emb[:, :length]
wrapped_atts[i, :length] = 1
return wrapped_embs, wrapped_atts
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
"""
Concatenate the batched input embedding and batched output embedding together.
Both the input and the output embedding should be right padded.
"""
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)):
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
)
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
)
# print('===================================')
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
# print('check out emb: ', output_embs[i][:2])
# print('check out pad emb: ', output_embs[i][-2:])
# print('+++++++++++++++++++++++++++++++++++')
#
# print('check attn before: ', input_atts[i][:this_input_ones])
# print('check attn after: ', input_atts[i][this_input_ones:])
# print('check attn gt before: ', output_atts[i][:3])
# print('check attn gt after: ', output_atts[i][-3:])
cat_embs = torch.stack(cat_embs)
cat_atts = torch.stack(cat_atts)
return cat_embs, cat_atts, input_lens
def get_conv_emb(self, conv_q, conv_a, conv_img):
"""concatenate conversation and make sure the model is only trained to regress the answer"""
regress_embs_list = []
targets_list = []
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
assigned_imgs = conv_img[batch_idx]
questions = [self.prompt_wrap(
img_embeds=img,
atts_img=None,
prompts=[q],
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
q_embs = [emb for emb, _ in questions]
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
cur_emb = []
cur_target = []
for i in range(len(questions)):
cur_emb.append(q_embs[i])
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
cur_emb.append(self.embed_tokens(answers[i].input_ids))
cur_target.append(answers[i].input_ids)
cur_emb = torch.cat(cur_emb, dim=1)
cur_target = torch.cat(cur_target, dim=1)
regress_embs_list.append(cur_emb)
targets_list.append(cur_target)
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
for batch_idx in range(batch_size):
cur_len = regress_embs_list[batch_idx].shape[1]
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
regress_attn[batch_idx, :cur_len] = 1
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
return regress_embeds, regress_attn, targets
def preparing_embedding(self, samples):
def remove_special_tokens(data):
# if "instruction_input" in data:
data = [instruct.replace(" [caption]","") for instruct in data]
data = [instruct.replace(" [vqa]","") for instruct in data]
data = [instruct.replace(" [grounding]","") for instruct in data]
data = [instruct.replace(" [identify]","") for instruct in data]
data = [instruct.replace(" [refer]","") for instruct in data]
return data
### prepare input tokens
if 'image' in samples:
img_embeds, img_atts = self.encode_img(samples["image"])
else:
img_embeds = img_atts = None
if 'conv_q' in samples:
# handeling conversation datasets
conv_q, conv_a = samples['conv_q'], samples['conv_a']
connect_sym = samples['connect_sym'][0]
conv_q = [q.split(connect_sym)for q in conv_q]
conv_a = [a.split(connect_sym) for a in conv_a]
conv_img = assign_imgs(conv_q, img_embeds)
if self.chat_template:
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
else:
instruction = samples["instruction_input"] if "instruction_input" in samples else None
# print("instruction before", instruction)
if self.remove_template:
instruction = remove_special_tokens(instruction)
# print("instruction after", instruction)
if self.chat_template:
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
if 'length' in samples:
# the input is a image train (like videos)
bsz, pn, hs = img_embeds.shape
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
else:
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
### prepare target tokens
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["answer"]]
regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(self.device)
regress_token_ids = regress_tokens.input_ids
regress_atts = regress_tokens.attention_mask
part_targets = regress_token_ids.masked_fill(
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
)
regress_embeds = self.embed_tokens(regress_token_ids)
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
def forward(self, samples, reduction="mean"):
# prepare the embedding to condition and the embedding to regress
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
self.preparing_embedding(samples)
# concat the embedding to condition and the embedding to regress
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
# get bos token embedding
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos)
bos_atts = attention_mask[:, :1]
# add bos token at the begining
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
# ensemble the final targets
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(self.device).fill_(-100)
for i, target in enumerate(part_targets):
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
reduction=reduction
)
loss = outputs.loss
return {"loss": loss}
@torch.no_grad()
def generate(
self,
images,
texts,
use_nucleus_sampling=False,
num_beams=1,
max_new_tokens=20,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
lengths=None,
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
img_embeds, atts_img = self.encode_img(images.to(self.device))
if lengths is not None:
image_lists = []
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
for idx, img_embed in enumerate(img_embeds):
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
else:
image_lists = [[image_emb[None]] for image_emb in img_embeds]
assert len(texts) == len(image_lists)
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
# stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
return answers
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
all_losses = []
for answer in answers:
choice_samples = {
'image': images,
'instruction_input': texts,
'answer': answer
}
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
all_losses.append(loss)
torch.cuda.empty_cache()
all_losses = torch.cat(all_losses, dim=-1)
if num_cand is not None:
for i in range(all_losses.shape[0]):
all_losses[i, num_cand[i]:] = 9999
output_class_ranks = torch.argsort(all_losses, dim=-1)
return output_class_ranks.tolist()
def predict_answers(
self,
samples,
num_beams=5,
inference_method="generate",
max_len=10,
min_len=1,
num_ans_candidates=128,
answer_list=None,
prompt="",
length_penalty=0,
**kwargs
):
'''
function for open-ended VQA
'''
images = samples["image"].cuda()
texts = samples["instruction_input"]
output_text = self.generate(
images=images,
texts=texts,
num_beams=num_beams,
max_new_tokens=max_len,
min_length=min_len,
length_penalty=length_penalty
)
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
output_text = self._lemmatize(output_text)
return output_text
def predict_class(
self,
samples,
num_beams=5,
inference_method="generate",
max_len=10,
min_len=1,
num_ans_candidates=5,
answer_list=None,
prompt="",
length_penalty=0,
**kwargs
):
'''
function for multi-choice VQA
'''
image = samples["image"].cuda()
instruction = samples['instruction_input']
answers = samples["choices"]
num_cand = samples["num_choices"]
ranks = self.multi_select(image, instruction, answers, num_cand)
pred_ans = []
for i, rank in enumerate(ranks):
pred = answers[rank[0]][i]
pred_ans.append(pred)
return pred_ans
def embed_tokens(self, token_ids):
try:
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
except AttributeError:
embeds = self.llama_model.model.embed_tokens(token_ids)
return embeds
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 300)
end_sym = cfg.get("end_sym", '\n')
lora_r = cfg.get("lora_r",64)
lora_alpha = cfg.get("lora_alpha",16)
chat_template = cfg.get("chat_template",False)
system_prompt = cfg.get("system_prompt", False)
token_pooling = cfg.get("token_pooling",True)
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
max_context_len = cfg.get("max_context_len", 3800)
remove_template = cfg.get("remove_template", False)
model = cls(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
low_resource=low_resource,
end_sym=end_sym,
lora_r = lora_r,
lora_alpha = lora_alpha,
chat_template = chat_template,
system_prompt = system_prompt,
token_pooling = token_pooling,
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
max_context_len=max_context_len,
remove_template = remove_template
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model
def assign_imgs(batched_instruct_list, batched_img_embeds):
'''this function is used when the data is interleaved.
the interlevaed data is separated, and this function assign
corresponding image embeddings to each segment'''
if len(batched_img_embeds.shape) == 3:
batched_img_embeds = batched_img_embeds[:, None]
batched_assigned = []
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
img_idx = 0
assigned_img = []
n_assigned = []
for instruct in instruct_list:
n_img = instruct.count('<ImageHere>')
if n_img > 0: # this instruction include images.
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
img_idx += n_img
n_assigned.append(n_img)
else: # this instruction doesn't include images
assigned_img.append(None)
n_assigned.append(None)
batched_assigned.append(assigned_img)
return batched_assigned

View file

@ -0,0 +1,25 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
{"role": "user", "content": "What is your favourite condiment?"},
{"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
{"role": "user", "content": "Do you have mayonnaise recipes?"}
]
p="Well, I'm quite partial to a good squeeze of fresh lemon juice."
encoded_input = tokenizer(p, return_tensors='pt')
embeds = model.model.embed_tokens(encoded_input.input_ids)
print(embeds.shape)
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(device)
model.to(device)
generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])

View file

@ -0,0 +1,112 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
class LlamaForCausalLM(LlamaForCausalLMOrig):
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction=reduction)
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if reduction == "none":
loss = loss.view(logits.size(0), -1).mean(1)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View file

@ -0,0 +1,112 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
class LlamaForCausalLM(LlamaForCausalLMOrig):
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction=reduction)
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if reduction == "none":
loss = loss.view(logits.size(0), -1).mean(1)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

File diff suppressed because it is too large Load diff

287
models/backbones/moes.py Normal file
View file

@ -0,0 +1,287 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, mask=None):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if mask is not None:
# if mask.dim() != x.dim():
# expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
# else:
# expanded_mask = mask
mask = mask.bool()
attn = attn.masked_fill(~mask, float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class MoELayer(nn.Module):
def __init__(
self,
dim,
num_heads,
expert_type,
use_sep_spatial_temp_experts=True,
has_hist=False,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LlamaRMSNorm,
):
super().__init__()
self.has_hist = has_hist
self.use_sep_spatial_temp_experts = use_sep_spatial_temp_experts
self.norm_att = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
mlp_hidden_dim = int(dim * mlp_ratio)
if expert_type == 'modalities':
# EXPERT CONSTRUCTION
if use_sep_spatial_temp_experts:
# Spatial expert
self.norm_spatial = norm_layer(dim)
self.mlp_spatial = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# Temporal expert
self.norm_temp = norm_layer(dim)
self.mlp_temp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# Vis expert
self.norm_vis = norm_layer(dim)
self.mlp_vis = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# caption expert
self.norm_cap = norm_layer(dim)
self.mlp_cap = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# history expert
if has_hist:
self.norm_hist = norm_layer(dim)
self.mlp_hist = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
elif expert_type == 'fusion':
# Fusion expert
self.norm_fusion = norm_layer(dim)
self.mlp_fusion = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
else:
raise ValueError
def forward(self, x, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len=None, is_vid=False, mask=None, only_text=False, expert_permutation=None):
if self.has_hist:
assert hist_feat_len is not None
x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
x = x + self.drop_path(x_shortcut)
len_init = x.size(1)
# bs, h_dim = x.size(0), x.size(-1)
# device = x.device
# if only_text:
# # end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
# # x = x[:, special_toks_indices['<caption>']: end_idx_caption, :]
# x = x + self.drop_path(self.mlp_cap(self.norm_cap(x)))
if expert_flag == 'modalities':
if self.use_sep_spatial_temp_experts:
x_spatial = x[:, :vis_feat_len]
if expert_permutation is not None:
if expert_permutation['spatial'] == 'temporal':
x_spatial = x_spatial + self.drop_path(self.mlp_temp(self.norm_temp(x_spatial)))
elif expert_permutation['spatial'] == 'caption':
x_spatial = x_spatial + self.drop_path(self.mlp_cap(self.norm_cap(x_spatial)))
elif expert_permutation['spatial'] == 'history':
x_spatial = x_spatial + self.drop_path(self.mlp_hist(self.norm_hist(x_spatial)))
elif expert_permutation['spatial'] == 'spatial':
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
x_vis = x_spatial
else:
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
x_vis = x_spatial
if is_vid:
x_temporal = x[:, vis_feat_len:2*vis_feat_len]
if expert_permutation is not None:
if expert_permutation['temporal'] == 'spatial':
x_temporal = x_temporal + self.drop_path(self.mlp_spatial(self.norm_spatial(x_temporal)))
elif expert_permutation['temporal'] == 'caption':
x_temporal = x_temporal + self.drop_path(self.mlp_cap(self.norm_cap(x_temporal)))
elif expert_permutation['temporal'] == 'history':
x_temporal = x_temporal + self.drop_path(self.mlp_hist(self.norm_hist(x_temporal)))
elif expert_permutation['temporal'] == 'temporal':
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
else:
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
x_vis = torch.concat([x_spatial, x_temporal], dim=1)
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
else:
x_vis = x[:, :vis_feat_len]
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
if self.has_hist:
x_caption = x[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
if expert_permutation is not None:
if expert_permutation['caption'] == 'spatial':
x_caption = x_caption + self.drop_path(self.mlp_spatial(self.norm_spatial(x_caption)))
elif expert_permutation['caption'] == 'temporal':
x_caption = x_caption + self.drop_path(self.mlp_temp(self.norm_temp(x_caption)))
elif expert_permutation['caption'] == 'history':
x_caption = x_caption + self.drop_path(self.mlp_hist(self.norm_hist(x_caption)))
elif expert_permutation['caption'] == 'caption':
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
else:
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
x_history = x[:, -hist_feat_len:]
if expert_permutation is not None:
if expert_permutation['history'] == 'spatial':
x_history = x_history + self.drop_path(self.mlp_spatial(self.norm_spatial(x_history)))
elif expert_permutation['history'] == 'temporal':
x_history = x_history + self.drop_path(self.mlp_temp(self.norm_temp(x_history)))
elif expert_permutation['history'] == 'caption':
x_history = x_history + self.drop_path(self.mlp_cap(self.norm_cap(x_history)))
elif expert_permutation['history'] == 'history':
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
else:
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
# concat the features back
x = torch.cat([x_vis, x_caption, x_history], dim=1)
else:
x_caption = x[:, -cap_feat_len:]
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
x = torch.cat([x_vis, x_caption], dim=1)
assert x.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
x.size(1), len_init
)
elif expert_flag == 'fusion':
x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))
return x
class Pooler(nn.Module):
def __init__(self, hidden_size):
super(Pooler, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
pooled_states = hidden_states[:, 0]
pooled_output = self.dense(pooled_states)
pooled_output = self.activation(pooled_output)
return pooled_output

View file

@ -0,0 +1,234 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
import warnings
from torch import Tensor
from typing import Optional, Tuple
from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig
class MoELayer(nn.Module):
def __init__(self, config, expert_type):
super(MoELayer, self).__init__()
self.config = config
self.expert_type = expert_type
self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
# Shared across all experts
self.attention = BertAttention(self.bert_config)
# One for each expert
if expert_type == 'modalities':
# Spatial expert
self.intermediate_spatial = BertIntermediate(self.bert_config)
self.output_spatial = BertOutput(self.bert_config)
# Temporal expert
self.intermediate_temporal = BertIntermediate(self.bert_config)
self.output_temporal = BertOutput(self.bert_config)
# Vis Expert
self.intermediate_vis = BertIntermediate(self.bert_config)
self.output_vis = BertOutput(self.bert_config)
# Caption Expert
self.intermediate_caption = BertIntermediate(self.bert_config)
self.output_caption = BertOutput(self.bert_config)
if config.stage != 'stage_1':
# History Expert
self.intermediate_history = BertIntermediate(self.bert_config)
self.output_history = BertOutput(self.bert_config)
# Fusion expert
elif expert_type == 'fusion':
self.intermediate_fusion = BertIntermediate(self.bert_config)
self.output_fusion = BertOutput(self.bert_config)
else:
raise ValueError
self._init_weights()
def _init_weights(self):
for _, m in dict(self.named_modules()).items():
if isinstance(m, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
elif isinstance(m, nn.LayerNorm):
m.bias.data.zero_()
m.weight.data.fill_(1.0)
if isinstance(m, nn.Linear) and m.bias is not None:
m.bias.data.zero_()
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
if dtype is None:
dtype = self.dtype
if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
# if self.config.is_decoder:
# extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
# input_shape, attention_mask, device
# )
# else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
return extended_attention_mask
def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):
input_shape = hidden_states.size()[:-1]
# dtype = mask.dtype
# device = mask.device
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)
self_attention_outputs = self.attention(
hidden_states,
attention_mask=extended_attention_mask,
output_attentions=output_attentions,
head_mask=None
)
attention_output = self_attention_outputs[0]
# outputs = self_attention_outputs[1:]
len_init = attention_output.size(1)
# bs, h_dim = x.size(0), x.size(-1)
# device = x.device
if expert_flag == 'modalities':
if only_text:
intermediate_output = self.intermediate_caption(attention_output)
layer_output = self.output_caption(intermediate_output, attention_output)
else:
# split the input first into different parts/modalities
unchanged = attention_output[:, :special_toks_indices['<vis>'], :]
end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
attention_spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]
end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
attention_caption = attention_output[:, special_toks_indices['<caption>']: end_idx_caption, :]
attention_temporal, attention_history = None, None
if '<temporal>' in special_toks_indices:
end_idx_temporal = special_toks_indices['<caption>']
attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]
if '<history>' in special_toks_indices:
end_idx_history = special_toks_indices['</s>'] + 1
attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]
# Expert activation
# 1- Spatial
intermediate_spatial = self.intermediate_spatial(attention_spatial)
output_sapatial = self.output_spatial(intermediate_spatial, attention_spatial)
output_vis = output_sapatial
# 2- Temporal
if attention_temporal is not None:
intermediate_temporal = self.intermediate_temporal(attention_temporal)
output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)
attention_vis = torch.concat([output_sapatial, output_temporal], dim=1)
intermediate_vis = self.intermediate_vis(attention_vis)
output_vis = self.output_vis(intermediate_vis, attention_vis)
# 3- Caption
intermediate_caption = self.intermediate_caption(attention_caption)
output_caption = self.output_caption(intermediate_caption, attention_caption)
# 4- History
if attention_history is not None:
intermediate_history = self.intermediate_history(attention_history)
output_history = self.output_history(intermediate_history, attention_history)
output_list = [unchanged, output_vis, output_caption]
if attention_history is not None:
output_list.append(output_history)
# Concat the features back
layer_output = torch.concat(output_list, dim=1)
assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
layer_output.size(1), len_init
)
elif expert_flag == 'fusion':
intermediate_output = self.intermediate_fusion(attention_output)
layer_output = self.output_fusion(intermediate_output, attention_output)
return layer_output
class MoEPooler(nn.Module):
def __init__(self):
super(MoEPooler, self).__init__()
self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
hidden_size = self.bert_config.hidden_size
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
self._init_weights()
def _init_weights(self):
for _, m in dict(self.named_modules()).items():
if isinstance(m, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
elif isinstance(m, nn.LayerNorm):
m.bias.data.zero_()
m.weight.data.fill_(1.0)
if isinstance(m, nn.Linear) and m.bias is not None:
m.bias.data.zero_()
def forward(self, hidden_states, idx):
pooled_states = hidden_states[:, idx]
pooled_output = self.dense(pooled_states)
pooled_output = self.activation(pooled_output)
return pooled_output

View file

@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, mask=None):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if mask is not None:
if mask.dim() != x.dim():
expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
else:
expanded_mask = mask
expanded_mask = expanded_mask.bool()
attn = attn.masked_fill(~expanded_mask, float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class MoELayer(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.SiLU,
norm_layer=LlamaRMSNorm,
):
super().__init__()
self.norm_att = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# EXPERT CONSTRUCTION
mlp_hidden_dim = int(dim * mlp_ratio)
# Spatial expert
self.norm_spatial = norm_layer(dim)
self.mlp_spatial = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# Temporal expert
self.norm_temp = norm_layer(dim)
self.mlp_temp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# Vis expert
self.norm_vis = norm_layer(dim)
self.mlp_vis = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# caption expert
self.norm_cap = norm_layer(dim)
self.mlp_cap = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# history expert
self.norm_hist = norm_layer(dim)
self.mlp_hist = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# Fusion expert
self.norm_fusion = norm_layer(dim)
self.mlp_fusion = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
# expert_flag:{Only Text : 00 , Only Image : 01, Fusion : 10, Text & Image : 11} (BINARY)
# expert_flag:
# 0:
def forward(self, x, special_toks_indices, expert_flag, mask=None):
x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
x = x + self.drop_path(x_shortcut)
bs, h_dim = x.size(0), x.size(-1)
device = x.device
if expert_flag == 'modalities':
end_index = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
spatial_feats = x[:, special_toks_indices['<spatial>']: end_index, :]
spatial_feats = spatial_feats + self.drop_path(self.mlp_spatial(self.norm_spatial(spatial_feats)))
spatial_index = torch.arange(special_toks_indices['<spatial>'], end_index, device=device)
spatial_index = spatial_index.unsqueeze(0).unsqueeze(-1)
spatial_index = spatial_index.repeat(bs, 1, h_dim)
x = x.scatter(1, spatial_index, spatial_feats)
# x[:, special_toks_indices['<spatial>']: special_toks_indices['<temporal>'], :] = spatial_feats
end_index = special_toks_indices.get('<history>', special_toks_indices['</s>'])
caption_feats = x[:, special_toks_indices['<caption>']: end_index, :]
caption_feats = caption_feats + self.drop_path(self.mlp_cap(self.norm_cap(caption_feats)))
caption_index = torch.arange(special_toks_indices['<caption>'], end_index, device=device)
caption_index = caption_index.unsqueeze(0).unsqueeze(-1)
caption_index = caption_index.repeat(bs, 1, h_dim)
x = x.scatter(1, caption_index, caption_feats)
# x[:, special_toks_indices['<caption>']: special_toks_indices['</s>'], :] = caption_feats
if '<temporal>' in special_toks_indices:
temporal_feats = x[:, special_toks_indices['<temporal>']: special_toks_indices['<caption>'], :]
temporal_feats = temporal_feats + self.drop_path(self.mlp_temp(self.norm_temp(temporal_feats)))
temporal_index = torch.arange(special_toks_indices['<temporal>'], special_toks_indices['<caption>'], device=device)
temporal_index = temporal_index.unsqueeze(0).unsqueeze(-1)
temporal_index = temporal_index.repeat(bs, 1, h_dim)
x = x.scatter(1, temporal_index, temporal_feats)
# x[:, special_toks_indices['<temporal>']: special_toks_indices['<caption>'], :] = temporal_feats
vis_feats = x[:, special_toks_indices['<vis>']: special_toks_indices['<caption>'], :]
vis_feats = vis_feats + self.drop_path(self.mlp_vis(self.norm_vis(vis_feats)))
vis_index = torch.arange(special_toks_indices['<vis>'], special_toks_indices['<caption>'], device=device)
vis_index = vis_index.unsqueeze(0).unsqueeze(-1)
vis_index = vis_index.repeat(bs, 1, h_dim)
x = x.scatter(1, vis_index, vis_feats)
# x[:, special_toks_indices['<vis>']: special_toks_indices['<caption>'], :] = vis_feats
if '<history>' in special_toks_indices:
history_feats = x[:, special_toks_indices['<history>']: special_toks_indices['</s>'], :]
history_feats = history_feats + self.drop_path(self.mlp_hist(self.norm_hist(history_feats)))
history_index = torch.arange(special_toks_indices['<history>'], special_toks_indices['</s>'], device=device)
history_index = history_index.unsqueeze(0).unsqueeze(-1)
history_index = history_index.repeat(bs, 1, h_dim)
x = x.scatter(1, history_index, history_feats)
elif expert_flag == 'fusion':
x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))
return x, attn
# if expert_flag == 2:
# x = x + self.drop_path(self.mlp(self.norm2(x)))
# elif expert_flag == 0:
# x = (x[:, -it_split:])
# x = x + self.drop_path(self.sentence_mlp(self.sentence_norm(x)))
# elif expert_flag == 1:
# x = (x[:, :-it_split ])
# x = x + self.drop_path(self.image_mlp(self.image_norm(x)))
# elif expert_flag == 3:
# text, image = (x[:, :it_split], x[:, it_split:],)
# text = text + self.drop_path(self.sentence_mlp(self.sentence_norm(text)))
# image = image + self.drop_path(self.image_mlp(self.image_norm(image)))
# x = torch.cat([text, image], dim=1)
# elif expert_flag == 4:
# x = x + self.drop_path(self.generation_mlp(self.generation_norm(x)))
# return x, attn

0
models/common/__init__.py Executable file
View file

474
models/common/config.py Executable file
View file

@ -0,0 +1,474 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import json
from typing import Dict
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
class Config:
def __init__(self, args):
self.config = {}
self.args = args
# Register the config and configuration for setup
registry.register("configuration", self)
user_config = self._build_opt_list(self.args.options)
config = OmegaConf.load(self.args.cfg_path)
runner_config = self.build_runner_config(config)
model_config = self.build_model_config(config, **user_config)
dataset_config = self.build_dataset_config(config)
# Validate the user-provided runner configuration
# model and dataset configuration are supposed to be validated by the respective classes
# [TODO] validate the model/dataset configuration
# self._validate_runner_config(runner_config)
# Override the default configuration with user options.
self.config = OmegaConf.merge(
runner_config, model_config, dataset_config, user_config
)
def _validate_runner_config(self, runner_config):
"""
This method validates the configuration, such that
1) all the user specified options are valid;
2) no type mismatches between the user specified options and the config.
"""
runner_config_validator = create_runner_config_validator()
runner_config_validator.validate(runner_config)
def _build_opt_list(self, opts):
opts_dot_list = self._convert_to_dot_list(opts)
return OmegaConf.from_dotlist(opts_dot_list)
@staticmethod
def build_model_config(config, **kwargs):
model = config.get("model", None)
assert model is not None, "Missing model configuration file."
model_cls = registry.get_model_class(model.arch)
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
model_type = kwargs.get("model.model_type", None)
if not model_type:
model_type = model.get("model_type", None)
# else use the model type selected by user.
assert model_type is not None, "Missing model_type."
print("--------------")
print("model arch",model.arch)
print("model cls",model_cls)
model_config_path = model_cls.default_config_path(model_type=model_type)
model_config = OmegaConf.create()
# hierarchy override, customized config > default config
model_config = OmegaConf.merge(
model_config,
OmegaConf.load(model_config_path),
{"model": config["model"]},
)
return model_config
@staticmethod
def build_runner_config(config):
return {"run": config.run}
@staticmethod
def build_dataset_config(config):
datasets = config.get("datasets", None)
if datasets is None:
raise KeyError(
"Expecting 'datasets' as the root key for dataset configuration."
)
dataset_config = OmegaConf.create()
for dataset_name in datasets:
print("dataset name", dataset_name)
builder_cls = registry.get_builder_class(dataset_name)
dataset_config_type = datasets[dataset_name].get("type", "default")
dataset_config_path = builder_cls.default_config_path(
type=dataset_config_type
)
# hierarchy override, customized config > default config
dataset_config = OmegaConf.merge(
dataset_config,
OmegaConf.load(dataset_config_path),
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
)
return dataset_config
def _convert_to_dot_list(self, opts):
if opts is None:
opts = []
if len(opts) == 0:
return opts
has_equal = opts[0].find("=") != -1
if has_equal:
return opts
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
def get_config(self):
return self.config
@property
def run_cfg(self):
return self.config.run
@property
def datasets_cfg(self):
return self.config.datasets
@property
def model_cfg(self):
return self.config.model
def pretty_print(self):
logging.info("\n===== Running Parameters =====")
logging.info(self._convert_node_to_json(self.config.run))
logging.info("\n====== Dataset Attributes ======")
datasets = self.config.datasets
for dataset in datasets:
if dataset in self.config.datasets:
logging.info(f"\n======== {dataset} =======")
dataset_config = self.config.datasets[dataset]
logging.info(self._convert_node_to_json(dataset_config))
else:
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
logging.info(f"\n====== Model Attributes ======")
logging.info(self._convert_node_to_json(self.config.model))
def _convert_node_to_json(self, node):
container = OmegaConf.to_container(node, resolve=True)
return json.dumps(container, indent=4, sort_keys=True)
def to_dict(self):
return OmegaConf.to_container(self.config)
def node_to_dict(node):
return OmegaConf.to_container(node)
class ConfigValidator:
"""
This is a preliminary implementation to centralize and validate the configuration.
May be altered in the future.
A helper class to validate configurations from yaml file.
This serves the following purposes:
1. Ensure all the options in the yaml are defined, raise error if not.
2. when type mismatches are found, the validator will raise an error.
3. a central place to store and display helpful messages for supported configurations.
"""
class _Argument:
def __init__(self, name, choices=None, type=None, help=None):
self.name = name
self.val = None
self.choices = choices
self.type = type
self.help = help
def __str__(self):
s = f"{self.name}={self.val}"
if self.type is not None:
s += f", ({self.type})"
if self.choices is not None:
s += f", choices: {self.choices}"
if self.help is not None:
s += f", ({self.help})"
return s
def __init__(self, description):
self.description = description
self.arguments = dict()
self.parsed_args = None
def __getitem__(self, key):
assert self.parsed_args is not None, "No arguments parsed yet."
return self.parsed_args[key]
def __str__(self) -> str:
return self.format_help()
def add_argument(self, *args, **kwargs):
"""
Assume the first argument is the name of the argument.
"""
self.arguments[args[0]] = self._Argument(*args, **kwargs)
def validate(self, config=None):
"""
Convert yaml config (dict-like) to list, required by argparse.
"""
for k, v in config.items():
assert (
k in self.arguments
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
if self.arguments[k].type is not None:
try:
self.arguments[k].val = self.arguments[k].type(v)
except ValueError:
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
if self.arguments[k].choices is not None:
assert (
v in self.arguments[k].choices
), f"""{k} must be one of {self.arguments[k].choices}."""
return config
def format_arguments(self):
return str([f"{k}" for k in sorted(self.arguments.keys())])
def format_help(self):
# description + key-value pair string for each argument
help_msg = str(self.description)
return help_msg + ", available arguments: " + self.format_arguments()
def print_help(self):
# display help message
print(self.format_help())
def create_runner_config_validator():
validator = ConfigValidator(description="Runner configurations")
validator.add_argument(
"runner",
type=str,
choices=["runner_base", "runner_iter"],
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
runner runs based on iters. Default: runner_base""",
)
# add argumetns for training dataset ratios
validator.add_argument(
"train_dataset_ratios",
type=Dict[str, float],
help="""Ratios of training dataset. This is used in iteration-based runner.
Do not support for epoch-based runner because how to define an epoch becomes tricky.
Default: None""",
)
validator.add_argument(
"max_iters",
type=float,
help="Maximum number of iterations to run.",
)
validator.add_argument(
"max_epoch",
type=int,
help="Maximum number of epochs to run.",
)
# add arguments for iters_per_inner_epoch
validator.add_argument(
"iters_per_inner_epoch",
type=float,
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
)
lr_scheds_choices = registry.list_lr_schedulers()
validator.add_argument(
"lr_sched",
type=str,
choices=lr_scheds_choices,
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
)
task_choices = registry.list_tasks()
validator.add_argument(
"task",
type=str,
choices=task_choices,
help="Task to use, from {}".format(task_choices),
)
# add arguments for init_lr
validator.add_argument(
"init_lr",
type=float,
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
)
# add arguments for min_lr
validator.add_argument(
"min_lr",
type=float,
help="Minimum learning rate (after decay).",
)
# add arguments for warmup_lr
validator.add_argument(
"warmup_lr",
type=float,
help="Starting learning rate for warmup.",
)
# add arguments for learning rate decay rate
validator.add_argument(
"lr_decay_rate",
type=float,
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
)
# add arguments for weight decay
validator.add_argument(
"weight_decay",
type=float,
help="Weight decay rate.",
)
# add arguments for training batch size
validator.add_argument(
"batch_size_train",
type=int,
help="Training batch size.",
)
# add arguments for evaluation batch size
validator.add_argument(
"batch_size_eval",
type=int,
help="Evaluation batch size, including validation and testing.",
)
# add arguments for number of workers for data loading
validator.add_argument(
"num_workers",
help="Number of workers for data loading.",
)
# add arguments for warm up steps
validator.add_argument(
"warmup_steps",
type=int,
help="Number of warmup steps. Required if a warmup schedule is used.",
)
# add arguments for random seed
validator.add_argument(
"seed",
type=int,
help="Random seed.",
)
# add arguments for output directory
validator.add_argument(
"output_dir",
type=str,
help="Output directory to save checkpoints and logs.",
)
# add arguments for whether only use evaluation
validator.add_argument(
"evaluate",
help="Whether to only evaluate the model. If true, training will not be performed.",
)
# add arguments for splits used for training, e.g. ["train", "val"]
validator.add_argument(
"train_splits",
type=list,
help="Splits to use for training.",
)
# add arguments for splits used for validation, e.g. ["val"]
validator.add_argument(
"valid_splits",
type=list,
help="Splits to use for validation. If not provided, will skip the validation.",
)
# add arguments for splits used for testing, e.g. ["test"]
validator.add_argument(
"test_splits",
type=list,
help="Splits to use for testing. If not provided, will skip the testing.",
)
# add arguments for accumulating gradient for iterations
validator.add_argument(
"accum_grad_iters",
type=int,
help="Number of iterations to accumulate gradient for.",
)
# ====== distributed training ======
validator.add_argument(
"device",
type=str,
choices=["cpu", "cuda"],
help="Device to use. Support 'cuda' or 'cpu' as for now.",
)
validator.add_argument(
"world_size",
type=int,
help="Number of processes participating in the job.",
)
validator.add_argument("dist_url", type=str)
validator.add_argument("distributed", type=bool)
# add arguments to opt using distributed sampler during evaluation or not
validator.add_argument(
"use_dist_eval_sampler",
type=bool,
help="Whether to use distributed sampler during evaluation or not.",
)
# ====== task specific ======
# generation task specific arguments
# add arguments for maximal length of text output
validator.add_argument(
"max_len",
type=int,
help="Maximal length of text output.",
)
# add arguments for minimal length of text output
validator.add_argument(
"min_len",
type=int,
help="Minimal length of text output.",
)
# add arguments number of beams
validator.add_argument(
"num_beams",
type=int,
help="Number of beams used for beam search.",
)
# vqa task specific arguments
# add arguments for number of answer candidates
validator.add_argument(
"num_ans_candidates",
type=int,
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
)
# add arguments for inference method
validator.add_argument(
"inference_method",
type=str,
choices=["genearte", "rank"],
help="""Inference method to use for question answering. If rank, requires a answer list.""",
)
# ====== model specific ======
validator.add_argument(
"k_test",
type=int,
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
)
return validator

203
models/common/dist_utils.py Executable file
View file

@ -0,0 +1,203 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import functools
import os
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def init_distributed_mode(args):
if args.distributed is False:
print("Not using distributed mode")
args.rank = 0
return
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
args.rank = 0
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
"| distributed init (rank {}, world {}): {}".format(
args.rank, args.world_size, args.dist_url
),
flush=True,
)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=datetime.timedelta(
days=365
), # allow auto-downloading and de-compressing
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def get_dist_info():
if torch.__version__ < "1.0":
initialized = dist._initialized
else:
initialized = dist.is_initialized()
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else: # non-distributed training
rank = 0
world_size = 1
return rank, world_size
def main_process(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
def download_cached_file(url, check_hash=True, progress=False):
"""
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
"""
def get_cached_file_path():
# a hack to sync the file path across processes
parts = torch.hub.urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
return cached_file
if is_main_process():
timm_hub.download_cached_file(url, check_hash, progress)
if is_dist_avail_and_initialized():
dist.barrier()
return get_cached_file_path()
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
@staticmethod
def forward(ctx, x):
output = [
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
torch.distributed.all_reduce(all_gradients)
return all_gradients[torch.distributed.get_rank()]
def all_gather_with_grad(tensors):
"""
Performs all_gather operation on the provided tensors.
Graph remains connected for backward grad computation.
"""
# Queue the gathered tensors
world_size = torch.distributed.get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
# tensor_all = GatherLayer.apply(tensors)
tensor_all = GatherLayer.apply(tensors)
return torch.cat(tensor_all, dim=0)
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
# if use distributed training
if not is_dist_avail_and_initialized():
return tensor
tensors_gather = [
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output

224
models/common/eval_utils.py Normal file
View file

@ -0,0 +1,224 @@
import argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
import sys
sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img')
from minigpt4.common.registry import registry
from minigpt4.common.config import Config
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
# from minigpt4.runners import *
from minigpt4.tasks import *
from pycocoevalcap.cider.cider import Cider
import os
import openai
from tqdm import tqdm
import json
import ast
import time
def eval_parser():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml")
parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
return parser
def prepare_texts(texts, conv_temp, template='<Img><ImageHere></Img>', lengths=None):
convs = [conv_temp.copy() for _ in range(len(texts))]
if lengths is None:
[conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)]
else:
templates = [template * length for length in lengths]
[conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)]
[conv.append_message(conv.roles[1], None) for conv in convs]
texts = [conv.get_prompt() for conv in convs]
return texts
def init_model(args):
print('Initialization Model')
cfg = Config(args)
cfg.model_cfg.ckpt = args.ckpt
cfg.model_cfg.lora_r = args.lora_r
cfg.model_cfg.lora_alpha = args.lora_alpha
model_config = cfg.model_cfg
model_config.low_resource = True
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:0')
# import pudb; pudb.set_trace()
key = list(cfg.datasets_cfg.keys())[0]
vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
print(vis_processor_cfg)
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
print('Initialization Finished')
return model, vis_processor
def computeIoU(bbox1, bbox2):
x1, y1, x2, y2 = bbox1
x3, y3, x4, y4 = bbox2
intersection_x1 = max(x1, x3)
intersection_y1 = max(y1, y3)
intersection_x2 = min(x2, x4)
intersection_y2 = min(y2, y4)
intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
union_area = bbox1_area + bbox2_area - intersection_area
iou = intersection_area / union_area
return iou
def eval_bleu(results):
bleus1,bleus2,bleus3,bleus4 = [],[],[],[]
for result in tqdm (results,desc="bleu_eval"):
gt = result['gt']
pred = result['pred']
bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1,0,0,0)))
bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5,0.5,0,0)))
bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33,0.33,0.33,0)))
bleus4.append(sentence_bleu([gt.split()], pred.split()))
# print(np.mean(bleus1),np.mean(bleus2),np.mean(bleus3),np.mean(bleus4),flush=True)
return {'bleu1':np.mean(bleus1),'bleu2':np.mean(bleus2),'bleu3':np.mean(bleus3),'bleu4':np.mean(bleus4)}
# Create a Cider object
cider_scorer = Cider()
def eval_cider(pred_result,gt_result):
# Compute CIDEr scores
mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result)
cider_scores_dict={}
for score,pred_vid_id,gt_vid_id in tqdm(zip(cider_scores.tolist(),pred_result,gt_result),desc="cider_eval") :
assert pred_vid_id==gt_vid_id
cider_scores_dict[pred_vid_id] = score
return {'mean_cider_scores':mean_cider_scores,'cider_scores':cider_scores_dict}
openai.api_key_path = "/home/ataallka/chatgpt_api.txt"
def chat_gpt_eval(results,output_path):
trial=0
gpt_results=[]
avg_chatgpt_score=0
existed_files={}
# read previous results from output path
for file in os.listdir(output_path):
if file.endswith(".json"):
with open(f'{output_path}/{file}') as json_file:
data = json.load(json_file)
gpt_results.append(data[0])
avg_chatgpt_score+=float(data[0]['chatgpt_score'])
existed_files[data[0]['video_name']]=True
length_output_path=len(os.listdir(output_path))
while len (results)!= length_output_path:
for res in tqdm(results,desc="chatgpt_eval"):
if existed_files.get(res['video_name'],False):
continue
video_name=res['video_name']
sentence_1=res['A']
sentence_2=res['pred']
try:
# prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
prompt=f"given these 2 sentences the first one is the ground truth descrption of a video and the second sentence is the generated text from a video summarization model,give it a score from 0 to 5 to evaluate the model summarization performance.the output should be only the score number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": prompt
}],
)
res['chatgpt_score']=response.choices[0].message['content']
out={'video_name':video_name,'chatgpt_score':response.choices[0].message['content']}
gpt_results.append(out)
# save each video result in a json file
with open(f'{output_path}/{video_name}.json', 'w') as f:
json.dump([out], f)
avg_chatgpt_score+=float(response.choices[0].message['content'])
except Exception as e:
print("chat gpt error",e)
print ("Finished chat gpt evaluation in trial",trial)
trial+=1
length_output_path=len(os.listdir(output_path))
return results,avg_chatgpt_score/len(results)
def GPT4_answer(question, answer,pred):
try:
# Compute the correctness score
completion = openai.ChatCompletion.create(
# model="gpt-3.5-turbo",
model='gpt-4',
messages=[
{
"role": "system",
"content":
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
"------"
"##INSTRUCTIONS: "
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
"- Consider synonyms or paraphrases as valid matches.\n"
"- Evaluate the correctness of the prediction compared to the answer."
},
{
"role": "user",
"content":
"Please evaluate the following video-based question-answer pair:\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Predicted Answer: {pred}\n\n"
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
}
]
)
# Convert response to a Python dictionary.
response_message = completion["choices"][0]["message"]["content"]
response_dict = ast.literal_eval(response_message)
return response_dict
except Exception as e:
print(f"Error : {e}")
return None
def GPT4_evaluation(val_result):
scores=[]
yes_count=0
no_count=0
for res in val_result:
gpt_response=GPT4_answer(res['Q'],res['A'],res['pred'])
if gpt_response is None:
continue
try:
scores.append(float(gpt_response['score']))
if 'yes' in gpt_response['pred'].lower():
yes_count+=1
elif 'no' in gpt_response['pred'].lower():
no_count+=1
except:
continue
avg_score=sum(scores)/len(scores)
accuracy=(yes_count/(yes_count+no_count))*100
print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
return avg_score,accuracy
# with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f:
# results = json.load(f)
# t1=time.time()
# avg_score,accuracy=GPT4_evaluation(results)
# print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
# print(f"Time taken: {time.time()-t1}")

24
models/common/gradcam.py Executable file
View file

@ -0,0 +1,24 @@
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform
def getAttMap(img, attMap, blur=True, overlap=True):
attMap -= attMap.min()
if attMap.max() > 0:
attMap /= attMap.max()
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
if blur:
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
attMap -= attMap.min()
attMap /= attMap.max()
cmap = plt.get_cmap("jet")
attMapV = cmap(attMap)
attMapV = np.delete(attMapV, 3, 2)
if overlap:
attMap = (
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
)
return attMap

195
models/common/logger.py Executable file
View file

@ -0,0 +1,195 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
from minigpt4.common import dist_utils
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not dist_utils.is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
)
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def global_avg(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
log_msg = [
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
if torch.cuda.is_available():
log_msg.append("max mem: {memory:.0f}")
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(
"{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def setup_logger():
logging.basicConfig(
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()],
)

119
models/common/optims.py Executable file
View file

@ -0,0 +1,119 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import math
from minigpt4.common.registry import registry
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
def __init__(
self,
optimizer,
max_epoch,
min_lr,
init_lr,
decay_rate=1,
warmup_start_lr=-1,
warmup_steps=0,
**kwargs
):
self.optimizer = optimizer
self.max_epoch = max_epoch
self.min_lr = min_lr
self.decay_rate = decay_rate
self.init_lr = init_lr
self.warmup_steps = warmup_steps
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
def step(self, cur_epoch, cur_step):
if cur_epoch == 0:
warmup_lr_schedule(
step=cur_step,
optimizer=self.optimizer,
max_step=self.warmup_steps,
init_lr=self.warmup_start_lr,
max_lr=self.init_lr,
)
else:
step_lr_schedule(
epoch=cur_epoch,
optimizer=self.optimizer,
init_lr=self.init_lr,
min_lr=self.min_lr,
decay_rate=self.decay_rate,
)
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
def __init__(
self,
optimizer,
max_epoch,
iters_per_epoch,
min_lr,
init_lr,
warmup_steps=0,
warmup_start_lr=-1,
**kwargs
):
self.optimizer = optimizer
self.max_epoch = max_epoch
self.iters_per_epoch = iters_per_epoch
self.min_lr = min_lr
self.init_lr = init_lr
self.warmup_steps = warmup_steps
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
def step(self, cur_epoch, cur_step):
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
if total_cur_step < self.warmup_steps:
warmup_lr_schedule(
step=total_cur_step,
optimizer=self.optimizer,
max_step=self.warmup_steps,
init_lr=self.warmup_start_lr,
max_lr=self.init_lr,
)
else:
cosine_lr_schedule(
epoch=total_cur_step,
optimizer=self.optimizer,
max_epoch=self.max_epoch * self.iters_per_epoch,
init_lr=self.init_lr,
min_lr=self.min_lr,
)
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
"""Decay the learning rate"""
lr = (init_lr - min_lr) * 0.5 * (
1.0 + math.cos(math.pi * epoch / max_epoch)
) + min_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
"""Warmup the learning rate"""
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
"""Decay the learning rate"""
lr = max(min_lr, init_lr * (decay_rate**epoch))
for param_group in optimizer.param_groups:
param_group["lr"] = lr

330
models/common/registry.py Executable file
View file

@ -0,0 +1,330 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
class Registry:
mapping = {
"builder_name_mapping": {},
"task_name_mapping": {},
"processor_name_mapping": {},
"model_name_mapping": {},
"lr_scheduler_name_mapping": {},
"runner_name_mapping": {},
"state": {},
"paths": {},
}
@classmethod
def register_builder(cls, name):
r"""Register a dataset builder to registry with key 'name'
Args:
name: Key with which the builder will be registered.
Usage:
from minigpt4.common.registry import registry
from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
"""
def wrap(builder_cls):
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
assert issubclass(
builder_cls, BaseDatasetBuilder
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
builder_cls
)
if name in cls.mapping["builder_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["builder_name_mapping"][name]
)
)
cls.mapping["builder_name_mapping"][name] = builder_cls
return builder_cls
return wrap
@classmethod
def register_task(cls, name):
r"""Register a task to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(task_cls):
from minigpt4.tasks.base_task import BaseTask
assert issubclass(
task_cls, BaseTask
), "All tasks must inherit BaseTask class"
if name in cls.mapping["task_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["task_name_mapping"][name]
)
)
cls.mapping["task_name_mapping"][name] = task_cls
return task_cls
return wrap
@classmethod
def register_model(cls, name):
r"""Register a task to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(model_cls):
# from minigpt4.models import BaseModel
# assert issubclass(
# model_cls, BaseModel
# ), "All models must inherit BaseModel class"
if name in cls.mapping["model_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["model_name_mapping"][name]
)
)
cls.mapping["model_name_mapping"][name] = model_cls
return model_cls
return wrap
@classmethod
def register_processor(cls, name):
r"""Register a processor to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(processor_cls):
from minigpt4.processors import BaseProcessor
assert issubclass(
processor_cls, BaseProcessor
), "All processors must inherit BaseProcessor class"
if name in cls.mapping["processor_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["processor_name_mapping"][name]
)
)
cls.mapping["processor_name_mapping"][name] = processor_cls
return processor_cls
return wrap
@classmethod
def register_lr_scheduler(cls, name):
r"""Register a model to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(lr_sched_cls):
if name in cls.mapping["lr_scheduler_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["lr_scheduler_name_mapping"][name]
)
)
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
return lr_sched_cls
return wrap
@classmethod
def register_runner(cls, name):
r"""Register a model to registry with key 'name'
Args:
name: Key with which the task will be registered.
Usage:
from minigpt4.common.registry import registry
"""
def wrap(runner_cls):
if name in cls.mapping["runner_name_mapping"]:
raise KeyError(
"Name '{}' already registered for {}.".format(
name, cls.mapping["runner_name_mapping"][name]
)
)
cls.mapping["runner_name_mapping"][name] = runner_cls
return runner_cls
return wrap
@classmethod
def register_path(cls, name, path):
r"""Register a path to registry with key 'name'
Args:
name: Key with which the path will be registered.
Usage:
from minigpt4.common.registry import registry
"""
assert isinstance(path, str), "All path must be str."
if name in cls.mapping["paths"]:
raise KeyError("Name '{}' already registered.".format(name))
cls.mapping["paths"][name] = path
@classmethod
def register(cls, name, obj):
r"""Register an item to registry with key 'name'
Args:
name: Key with which the item will be registered.
Usage::
from minigpt4.common.registry import registry
registry.register("config", {})
"""
path = name.split(".")
current = cls.mapping["state"]
for part in path[:-1]:
if part not in current:
current[part] = {}
current = current[part]
current[path[-1]] = obj
# @classmethod
# def get_trainer_class(cls, name):
# return cls.mapping["trainer_name_mapping"].get(name, None)
@classmethod
def get_builder_class(cls, name):
return cls.mapping["builder_name_mapping"].get(name, None)
@classmethod
def get_model_class(cls, name):
return cls.mapping["model_name_mapping"].get(name, None)
@classmethod
def get_task_class(cls, name):
return cls.mapping["task_name_mapping"].get(name, None)
@classmethod
def get_processor_class(cls, name):
return cls.mapping["processor_name_mapping"].get(name, None)
@classmethod
def get_lr_scheduler_class(cls, name):
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
@classmethod
def get_runner_class(cls, name):
return cls.mapping["runner_name_mapping"].get(name, None)
@classmethod
def list_runners(cls):
return sorted(cls.mapping["runner_name_mapping"].keys())
@classmethod
def list_models(cls):
return sorted(cls.mapping["model_name_mapping"].keys())
@classmethod
def list_tasks(cls):
return sorted(cls.mapping["task_name_mapping"].keys())
@classmethod
def list_processors(cls):
return sorted(cls.mapping["processor_name_mapping"].keys())
@classmethod
def list_lr_schedulers(cls):
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
@classmethod
def list_datasets(cls):
return sorted(cls.mapping["builder_name_mapping"].keys())
@classmethod
def get_path(cls, name):
return cls.mapping["paths"].get(name, None)
@classmethod
def get(cls, name, default=None, no_warning=False):
r"""Get an item from registry with key 'name'
Args:
name (string): Key whose value needs to be retrieved.
default: If passed and key is not in registry, default value will
be returned with a warning. Default: None
no_warning (bool): If passed as True, warning when key doesn't exist
will not be generated. Useful for MMF's
internal operations. Default: False
"""
original_name = name
name = name.split(".")
value = cls.mapping["state"]
for subname in name:
value = value.get(subname, default)
if value is default:
break
if (
"writer" in cls.mapping["state"]
and value == default
and no_warning is False
):
cls.mapping["state"]["writer"].warning(
"Key {} is not present in registry, returning default value "
"of {}".format(original_name, default)
)
return value
@classmethod
def unregister(cls, name):
r"""Remove an item from registry with key 'name'
Args:
name: Key which needs to be removed.
Usage::
from mmf.common.registry import registry
config = registry.unregister("config")
"""
return cls.mapping["state"].pop(name, None)
registry = Registry()

424
models/common/utils.py Executable file
View file

@ -0,0 +1,424 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse
import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from models.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
check_integrity,
download_file_from_google_drive,
extract_archive,
)
def now():
from datetime import datetime
return datetime.now().strftime("%Y%m%d%H%M")
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def get_cache_path(rel_path):
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
def get_abs_path(rel_path):
return os.path.join(registry.get_path("library_root"), rel_path)
def load_json(filename):
with open(filename, "r") as f:
return json.load(f)
# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
def makedir(dir_path):
"""
Create the directory if it does not exist.
"""
is_success = False
try:
if not g_pathmgr.exists(dir_path):
g_pathmgr.mkdirs(dir_path)
is_success = True
except BaseException:
print(f"Error creating directory: {dir_path}")
return is_success
def get_redirected_url(url: str):
"""
Given a URL, returns the URL it redirects to or the
original URL in case of no indirection
"""
import requests
with requests.Session() as session:
with session.get(url, stream=True, allow_redirects=True) as response:
if response.history:
return response.url
else:
return url
def to_google_drive_download_url(view_url: str) -> str:
"""
Utility function to transform a view URL of google drive
to a download URL for google drive
Example input:
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
Example output:
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
"""
splits = view_url.split("/")
assert splits[-1] == "view"
file_id = splits[-2]
return f"https://drive.google.com/uc?export=download&id={file_id}"
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
"""
Download a file from google drive
Downloading an URL from google drive requires confirmation when
the file of the size is too big (google drive notifies that
anti-viral checks cannot be performed on such files)
"""
import requests
with requests.Session() as session:
# First get the confirmation token and append it to the URL
with session.get(url, stream=True, allow_redirects=True) as response:
for k, v in response.cookies.items():
if k.startswith("download_warning"):
url = url + "&confirm=" + v
# Then download the content of the file
with session.get(url, stream=True, verify=True) as response:
makedir(output_path)
path = os.path.join(output_path, output_file_name)
total_size = int(response.headers.get("Content-length", 0))
with open(path, "wb") as file:
from tqdm import tqdm
with tqdm(total=total_size) as progress_bar:
for block in response.iter_content(
chunk_size=io.DEFAULT_BUFFER_SIZE
):
file.write(block)
progress_bar.update(len(block))
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
with open(filename, "wb") as fh:
with urllib.request.urlopen(
urllib.request.Request(url, headers={"User-Agent": "vissl"})
) as response:
with tqdm(total=response.length) as pbar:
for chunk in iter(lambda: response.read(chunk_size), ""):
if not chunk:
break
pbar.update(chunk_size)
fh.write(chunk)
def download_url(
url: str,
root: str,
filename: Optional[str] = None,
md5: Optional[str] = None,
) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under.
If None, use the basename of the URL.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
makedir(root)
# check if file is already present locally
if check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
return
# expand redirect chain if needed
url = get_redirected_url(url)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == "https":
url = url.replace("https:", "http:")
print(
"Failed download. Trying https -> http instead."
" Downloading " + url + " to " + fpath
)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)
def cache_url(url: str, cache_dir: str) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
parsed_url = urlparse(url)
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
makedir(dirname)
filename = url.split("/")[-1]
cached = os.path.join(dirname, filename)
with file_lock(cached):
if not os.path.isfile(cached):
logging.info(f"Downloading {url} to {cached} ...")
cached = download(url, dirname, filename=filename)
logging.info(f"URL {url} cached in {cached}")
return cached
# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
"""
Simply create the symlinks for a given file1 to file2.
Useful during model checkpointing to symlinks to the
latest successful checkpoint.
"""
try:
if g_pathmgr.exists(file2):
g_pathmgr.rm(file2)
g_pathmgr.symlink(file1, file2)
except Exception as e:
logging.info(f"Could NOT create symlink. Error: {e}")
def save_file(data, filename, append_to_json=True, verbose=True):
"""
Common i/o utility to handle saving data to various file formats.
Supported:
.pkl, .pickle, .npy, .json
Specifically for .json, users have the option to either append (default)
or rewrite by passing in Boolean value to append_to_json.
"""
if verbose:
logging.info(f"Saving data to file: {filename}")
file_ext = os.path.splitext(filename)[1]
if file_ext in [".pkl", ".pickle"]:
with g_pathmgr.open(filename, "wb") as fopen:
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
elif file_ext == ".npy":
with g_pathmgr.open(filename, "wb") as fopen:
np.save(fopen, data)
elif file_ext == ".json":
if append_to_json:
with g_pathmgr.open(filename, "a") as fopen:
fopen.write(json.dumps(data, sort_keys=True) + "\n")
fopen.flush()
else:
with g_pathmgr.open(filename, "w") as fopen:
fopen.write(json.dumps(data, sort_keys=True) + "\n")
fopen.flush()
elif file_ext == ".yaml":
with g_pathmgr.open(filename, "w") as fopen:
dump = yaml.dump(data)
fopen.write(dump)
fopen.flush()
else:
raise Exception(f"Saving {file_ext} is not supported yet")
if verbose:
logging.info(f"Saved data to file: {filename}")
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
"""
Common i/o utility to handle loading data from various file formats.
Supported:
.pkl, .pickle, .npy, .json
For the npy files, we support reading the files in mmap_mode.
If the mmap_mode of reading is not successful, we load data without the
mmap_mode.
"""
if verbose:
logging.info(f"Loading data from file: {filename}")
file_ext = os.path.splitext(filename)[1]
if file_ext == ".txt":
with g_pathmgr.open(filename, "r") as fopen:
data = fopen.readlines()
elif file_ext in [".pkl", ".pickle"]:
with g_pathmgr.open(filename, "rb") as fopen:
data = pickle.load(fopen, encoding="latin1")
elif file_ext == ".npy":
if mmap_mode:
try:
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(
fopen,
allow_pickle=allow_pickle,
encoding="latin1",
mmap_mode=mmap_mode,
)
except ValueError as e:
logging.info(
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
)
data = np.load(
filename,
allow_pickle=allow_pickle,
encoding="latin1",
mmap_mode=mmap_mode,
)
logging.info("Successfully loaded without g_pathmgr")
except Exception:
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
else:
with g_pathmgr.open(filename, "rb") as fopen:
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
elif file_ext == ".json":
with g_pathmgr.open(filename, "r") as fopen:
data = json.load(fopen)
elif file_ext == ".yaml":
with g_pathmgr.open(filename, "r") as fopen:
data = yaml.load(fopen, Loader=yaml.FullLoader)
elif file_ext == ".csv":
with g_pathmgr.open(filename, "r") as fopen:
data = pd.read_csv(fopen)
else:
raise Exception(f"Reading from {file_ext} is not supported yet")
return data
def abspath(resource_path: str):
"""
Make a path absolute, but take into account prefixes like
"http://" or "manifold://"
"""
regex = re.compile(r"^\w+://")
if regex.match(resource_path) is None:
return os.path.abspath(resource_path)
else:
return resource_path
def makedir(dir_path):
"""
Create the directory if it does not exist.
"""
is_success = False
try:
if not g_pathmgr.exists(dir_path):
g_pathmgr.mkdirs(dir_path)
is_success = True
except BaseException:
logging.info(f"Error creating directory: {dir_path}")
return is_success
def is_url(input_url):
"""
Check if an input string is a url. look for http(s):// and ignoring the case
"""
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
return is_url
def cleanup_dir(dir):
"""
Utility for deleting a directory. Useful for cleaning the storage space
that contains various training artifacts like checkpoints, data etc.
"""
if os.path.exists(dir):
logging.info(f"Deleting directory: {dir}")
shutil.rmtree(dir)
logging.info(f"Deleted contents of directory: {dir}")
def get_file_size(filename):
"""
Given a file, get the size of file in MB
"""
size_in_mb = os.path.getsize(filename) / float(1024**2)
return size_in_mb

View file

@ -0,0 +1,89 @@
# coding: utf-8
import sys
dataDir = '../../VQA'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
from vqa import VQA
from vqaEvaluation.vqaEval import VQAEval
import matplotlib.pyplot as plt
import skimage.io as io
import json
import random
import os
# set up file names and paths
versionType ='v2_' # this should be '' when using VQA v2.0 dataset
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType ='train2014'
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
resultType ='fake'
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
# An example result json file has been provided in './Results' folder.
[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
resultType, fileType) for fileType in fileTypes]
# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)
# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
By default it uses all the question ids in annotation file
"""
vqaEval.evaluate()
# print accuracies
print "\n"
print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
print "Per Question Type Accuracy is the following:"
for quesType in vqaEval.accuracy['perQuestionType']:
print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
print "\n"
print "Per Answer Type Accuracy is the following:"
for ansType in vqaEval.accuracy['perAnswerType']:
print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
print "\n"
# demo how to use evalQA to retrieve low score result
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
if len(evals) > 0:
print 'ground truth answers'
randomEval = random.choice(evals)
randomAnn = vqa.loadQA(randomEval)
vqa.showQA(randomAnn)
print '\n'
print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
ann = vqaRes.loadQA(randomEval)[0]
print "Answer: %s\n" %(ann['answer'])
imgId = randomAnn[0]['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
# plot accuracy for various question types
plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
plt.title('Per Question Type Accuracy', fontsize=10)
plt.xlabel('Question Types', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.show()
# save evaluation results to ./Results folder
json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))

View file

@ -0,0 +1 @@
author='aagrawal'

View file

@ -0,0 +1,192 @@
# coding=utf-8
__author__='aagrawal'
import re
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
class VQAEval:
def __init__(self, vqa, vqaRes, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
self.params = {'question_id': vqa.getQuesIds()}
self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
"couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
"hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
"he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
"Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
"maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
"she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
"somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
"somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
"someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
"something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
"there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
"they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
"wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
"whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
"youll": "you'll", "youre": "you're", "youve": "you've"}
self.manualMap = { 'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10'
}
self.articles = ['a',
'an',
'the'
]
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(\,)(\d)")
self.punct = [';', r"/", '[', ']', '"', '{', '}',
'(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!']
def evaluate(self, quesIds=None):
if quesIds == None:
quesIds = [quesId for quesId in self.params['question_id']]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
# print "computing accuracy"
step = 0
for quesId in quesIds:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
ansDic['answer'] = ansDic['answer'].strip()
resAns = res[quesId]['answer']
resAns = resAns.replace('\n', ' ')
resAns = resAns.replace('\t', ' ')
resAns = resAns.strip()
gtAcc = []
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = self.processPunctuation(ansDic['answer'])
ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
for gtAnsDatum in gts[quesId]['answers']:
otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
acc = min(1, float(len(matchingAns))/3)
gtAcc.append(acc)
quesType = gts[quesId]['question_type']
ansType = gts[quesId]['answer_type']
avgGTAcc = float(sum(gtAcc))/len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step%100 == 0:
self.updateProgress(step/float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
# print "Done computing accuracy"
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = self.periodStrip.sub("",
outText,
re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = ' '.join(outText)
return outText
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100*acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength*progress))
text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
sys.stdout.write(text)
sys.stdout.flush()

View file

@ -0,0 +1,73 @@
# coding: utf-8
from vqaTools.vqa import VQA
import random
import skimage.io as io
import matplotlib.pyplot as plt
import os
dataDir ='../../VQA'
versionType ='v2_' # this should be '' when using VQA v2.0 dataset
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
dataSubType ='train2014'
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
# initialize VQA api for QA annotations
vqa=VQA(annFile, quesFile)
# load and display QA annotations for given question types
"""
All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
"""
annIds = vqa.getQuesIds(quesTypes='how many');
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
# load and display QA annotations for given answer types
"""
ansTypes can be one of the following
yes/no
number
other
"""
annIds = vqa.getQuesIds(ansTypes='yes/no');
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()
# load and display QA annotations for given images
"""
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
"""
ids = vqa.getImgIds()
annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
anns = vqa.loadQA(annIds)
randomAnn = random.choice(anns)
vqa.showQA([randomAnn])
imgId = randomAnn['image_id']
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
if os.path.isfile(imgDir + imgFilename):
I = io.imread(imgDir + imgFilename)
plt.imshow(I)
plt.axis('off')
plt.show()

View file

@ -0,0 +1 @@
__author__ = 'aagrawal'

View file

@ -0,0 +1,179 @@
__author__ = 'aagrawal'
__version__ = '0.9'
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(COCO.function)"
import json
import datetime
import copy
class VQA:
def __init__(self, annotation_file=None, question_file=None):
"""
Constructor of VQA helper class for reading and visualizing questions and answers.
:param annotation_file (str): location of VQA annotation file
:return:
"""
# load dataset
self.dataset = {}
self.questions = {}
self.qa = {}
self.qqa = {}
self.imgToQA = {}
if not annotation_file == None and not question_file == None:
# print 'loading VQA annotations and questions into memory...'
time_t = datetime.datetime.utcnow()
dataset = json.load(open(annotation_file, 'r'))
questions = json.load(open(question_file, 'r'))
# print datetime.datetime.utcnow() - time_t
self.dataset = dataset
self.questions = questions
self.createIndex()
def createIndex(self):
imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
for ann in self.dataset['annotations']:
imgToQA[ann['image_id']] += [ann]
qa[ann['question_id']] = ann
for ques in self.questions['questions']:
qqa[ques['question_id']] = ques
# print 'index created!'
# create class members
self.qa = qa
self.qqa = qqa
self.imgToQA = imgToQA
def info(self):
"""
Print information about the VQA annotation file.
:return:
"""
# for key, value in self.datset['info'].items():
# print '%s: %s'%(key, value)
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
"""
Get question ids that satisfy given filter conditions. default skips that filter
:param imgIds (int array) : get question ids for given imgs
quesTypes (str array) : get question ids for given question types
ansTypes (str array) : get question ids for given answer types
:return: ids (int array) : integer array of question ids
"""
imgIds = imgIds if type(imgIds) == list else [imgIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
else:
anns = self.dataset['annotations']
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
ids = [ann['question_id'] for ann in anns]
return ids
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
"""
Get image ids that satisfy given filter conditions. default skips that filter
:param quesIds (int array) : get image ids for given question ids
quesTypes (str array) : get image ids for given question types
ansTypes (str array) : get image ids for given answer types
:return: ids (int array) : integer array of image ids
"""
quesIds = quesIds if type(quesIds) == list else [quesIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(quesIds) == 0:
anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
else:
anns = self.dataset['annotations']
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
ids = [ann['image_id'] for ann in anns]
return ids
def loadQA(self, ids=[]):
"""
Load questions and answers with the specified question ids.
:param ids (int array) : integer ids specifying question ids
:return: qa (object array) : loaded qa objects
"""
if type(ids) == list:
return [self.qa[id] for id in ids]
elif type(ids) == int:
return [self.qa[ids]]
def showQA(self, anns):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
for ann in anns:
quesId = ann['question_id']
print("Question: %s" % (self.qqa[quesId]['question']))
for ans in ann['answers']:
print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
def loadRes(self, resFile, quesFile):
"""
Load result file and return a result object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
res = VQA()
res.questions = json.load(open(quesFile))
res.dataset['info'] = copy.deepcopy(self.questions['info'])
res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
res.dataset['license'] = copy.deepcopy(self.questions['license'])
# print 'Loading and preparing results... '
time_t = datetime.datetime.utcnow()
anns = json.load(open(resFile))
assert type(anns) == list, 'results is not an array of objects'
annsQuesIds = [ann['question_id'] for ann in anns]
assert set(annsQuesIds) == set(self.getQuesIds()), \
'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
for ann in anns:
quesId = ann['question_id']
if res.dataset['task_type'] == 'Multiple Choice':
assert ann['answer'] in self.qqa[quesId][
'multiple_choices'], 'predicted answer is not one of the multiple choices'
qaAnn = self.qa[quesId]
ann['image_id'] = qaAnn['image_id']
ann['question_type'] = qaAnn['question_type']
ann['answer_type'] = qaAnn['answer_type']
# print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
res.dataset['annotations'] = anns
res.createIndex()
return res

View file

@ -0,0 +1,80 @@
Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
===================
## VQA v2.0 release ##
This release consists of
- Real
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
- 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
- 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
There is only one type of task
- Open-ended task
## VQA v1.0 release ##
This release consists of
- Real
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
- 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
- 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
- Abstract
- 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
- 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
- 600,000 answers for training and 300,000 answers for validation (10 per question)
There are two types of tasks
- Open-ended task
- Multiple-choice task (18 choices per question)
## Requirements ##
- python 2.7
- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
## Files ##
./Questions
- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
- [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
- [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
./Annotations
- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
- [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
- [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
./Images
- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
./PythonHelperTools
- This directory contains the Python API to read and visualize the VQA dataset
- vqaDemo.py (demo script)
- vqaTools (API to read and visualize data)
./PythonEvaluationTools
- This directory contains the Python evaluation code
- vqaEvalDemo.py (evaluation demo script)
- vqaEvaluation (evaluation code)
./Results
- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
./QuestionTypes
- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
- mscoco_question_types.txt
- abstract_v002_question_types.txt
## References ##
- [VQA: Visual Question Answering](http://visualqa.org/)
- [Microsoft COCO](http://mscoco.org/)
## Developers ##
- Aishwarya Agrawal (Virginia Tech)
- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).

View file

@ -0,0 +1,8 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"

View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2022 Allen Institute for Artificial Intelligence
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -0,0 +1,207 @@
# A-OKVQA
Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**.
Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public)
### Abstract
The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art visionlanguage models.
![dataset_web](https://user-images.githubusercontent.com/28768645/170799740-f0d9ea60-6aff-4322-98d5-cae8e05983f4.svg)
<hr>
#### Table of Contents
- [Getting started](#getting-started)
* [Downloading the dataset](#downloading-the-dataset)
- [Evaluation & Leaderboard](#evaluation)
- [Codebase](#codebase)
* [Preparing data](#preparing-data)
* [Models and Predictions](#models-and-predictions)
<hr>
## Getting started
```bash
git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git
cd aokvqa
export PYTHONPATH=.
conda env create --name aokvqa
conda activate aokvqa
```
### Downloading the dataset
```bash
export AOKVQA_DIR=./datasets/aokvqa/
mkdir -p ${AOKVQA_DIR}
curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
```
<details> <summary><b>Downloading COCO 2017</b></summary>
```bash
export COCO_DIR=./datasets/coco/
mkdir -p ${COCO_DIR}
for split in train val test; do
wget "http://images.cocodataset.org/zips/${split}2017.zip"
unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
done
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
```
</details>
Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code.
```python
import os
aokvqa_dir = os.getenv('AOKVQA_DIR')
from load_aokvqa import load_aokvqa, get_coco_path
train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test'
```
<details> <summary><b>Example dataset entry</b></summary>
```python
dataset_example = train_dataset[0]
print(dataset_example['question_id'])
# 22MexNkBPpdZGX6sxbxVBH
coco_dir = os.getenv('COCO_DIR')
image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
print(image_path)
# ./datasets/coco/train2017/000000299207.jpg
print(dataset_example['question'])
print(dataset_example['choices'])
# What is the man by the bags awaiting?
# ['skateboarder', 'train', 'delivery', 'cab']
correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
# Corrrect: cab
print(dataset_example['rationales'][0])
# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
```
</details>
## Evaluation
Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either `multiple_choice` or `direct_answer` field if you only want to evaluate one setting.
```python
{
'<question_id>' : {
'multiple_choice' : '<prediction>',
'direct_answer' : '<prediction>'
}
}
```
You can run evaluation on the validation set as follows.
```bash
python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
```
### Leaderboard
You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
## Codebase
We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
### Preparing data
```bash
export FEATURES_DIR=./features/
mkdir -p ${FEATURES_DIR}
```
You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
```bash
python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
for split in train val test; do
python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
done
```
<details> <summary><b>For training ClipCap with a transformer mapping network</b></summary>
If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
</details>
<details> <summary><b>For ResNet and BERT input features</b></summary>
Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
```bash
# ResNet
for split in train val test; do
python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
done
# BERT
for split in train val test; do
python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
done
```
</details>
### Models and Predictions
```bash
export LOG_DIR=./logs/
export PREDS_DIR=./predictions/
export PT_MODEL_DIR=./pretrained_models/
mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR}
```
<details> <summary><b>Download our pretrained model weights</b></summary>
```bash
# Checkpoints for transfer learning experiments
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
# Checkpoints for ClipCap models (generating answers and rationales)
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
```
</details>
We have included instructions for replicating each of our experiments (see README.md files below).
All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above.
- [Heuristics](./heuristics/README.md)
- [Transfer Learning Experiments](./transfer_experiments/README.md)
- [Querying GPT-3](./gpt3/README.md)
- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set.
We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.)
```bash
python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json
# repeat for test split ...
```

View file

@ -0,0 +1,45 @@
import os
import argparse
from collections import Counter
import pathlib
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
args = parser.parse_args()
# Build vocab from train set: correct choices + (direct answers appearing in >= 3 )
train_set = load_aokvqa(args.aokvqa_dir, 'train')
vocab = []
all_choices = Counter()
direct_answers = Counter()
for i in train_set:
vocab.append( i['choices'][i['correct_choice_idx']] )
all_choices.update(i['choices'])
direct_answers.update(set(i['direct_answers']))
vocab += [k for k,v in all_choices.items() if v >= 3]
vocab += [k for k,v in direct_answers.items() if v >= 3]
vocab = sorted(set(vocab))
print(f"Vocab size: {len(vocab)}")
# Save vocabulary Output
with open(args.output_file, 'w') as f:
for v in vocab:
print(v, file=f)
## Check validation set coverage
val_set = load_aokvqa(args.aokvqa_dir, 'val')
val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set]
val_acc = sum(val_acc) / len(val_acc) * 100
print(f"Val set coverage: {val_acc:.2f}" )

View file

@ -0,0 +1,26 @@
import json
from tqdm import tqdm
import argparse
import pathlib
import torch
import clip
parser = argparse.ArgumentParser()
parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file')
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
args = parser.parse_args()
assert args.output_file.suffix == '.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(args.model_type, device=device)
with torch.no_grad():
a = open(args.vocab_file).read().splitlines()
mc_text = clip.tokenize(a).to(device)
mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0]
mc_text_features = mc_text_features.float()
model_name = args.model_type.replace('/', '-').replace('@', '-')
torch.save(mc_text_features, args.output_file)

View file

@ -0,0 +1,50 @@
import os
import argparse
import pathlib
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
args = parser.parse_args()
assert args.output_file.suffix == '.pt'
## Load dataset
dataset = load_aokvqa(args.aokvqa_dir, args.split)
## Load model
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
## Encoding loop
with torch.no_grad():
embeddings = {}
for d in tqdm(dataset):
encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt')
encoded_input = {k:v.to(device) for k,v in encoded_input.items()}
e = mean_pooling(model(**encoded_input), encoded_input['attention_mask'])
embeddings[d['question_id']] = {
'question' : e[0].cpu()
}
torch.save(embeddings, args.output_file)

View file

@ -0,0 +1,51 @@
import os
from PIL import Image
from tqdm import tqdm
import argparse
import pathlib
import torch
import clip
from load_aokvqa import load_aokvqa, get_coco_path
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
args = parser.parse_args()
assert args.output_file.suffix == '.pt'
## Load dataset
dataset = load_aokvqa(args.aokvqa_dir, args.split)
## Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(args.model_type, device=device)
## Encoding loop
with torch.no_grad():
embeddings = {}
for d in tqdm(dataset):
q = d["question"]
q_text = clip.tokenize(q).to(device)
q_text_features = model.encode_text(q_text)
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir))
img = preprocess(img).unsqueeze(0).to(device)
image_features = model.encode_image(img)
embeddings[d['question_id']] = {
'question' : q_text_features[0].float().cpu(),
'image' : image_features[0].float().cpu(),
}
torch.save(embeddings, args.output_file)

View file

@ -0,0 +1,62 @@
import os
import argparse
import pathlib
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms as T
from load_aokvqa import load_aokvqa, get_coco_path
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
args = parser.parse_args()
assert args.output_file.suffix == '.pt'
## Load dataset
dataset = load_aokvqa(args.aokvqa_dir, args.split)
## Load model
resnet_preprocess = T.Compose([
T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC),
T.CenterCrop(size=(224, 224)),
T.ToTensor(),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
device = "cuda" if torch.cuda.is_available() else "cpu"
resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(
*list(resnet_model.children())[:-1],
nn.Flatten()
) # strip classification layer
resnet_model = resnet_model.to(device)
## Encoding loop
with torch.no_grad():
embeddings = {}
for d in tqdm(dataset):
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB')
resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)
resnet_features = resnet_model(resnet_input)
embeddings[d['question_id']] = {
'image' : resnet_features[0].cpu()
}
torch.save(embeddings, args.output_file)

View file

@ -0,0 +1,36 @@
name: aokvqa
channels:
- pytorch
- nvidia
- huggingface
- conda-forge
- defaults
dependencies:
- python=3.7
- cudatoolkit=11.3
- numpy=1.21.6
- pytorch=1.11.0
- torchvision=0.12.0
- pytorch-lightning=1.6.3
- torchmetrics=0.8.1
- gdown=4.4.0
- pip=22.0.4
- pip:
- argparse==1.4.0
- Pillow==9.0.1
- tensorboard==2.9.0
- ftfy==6.1.1
- regex==2022.3.15
- tqdm==4.64.0
- clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
- openai==0.18.1
- nltk==3.7
- sacrebleu==2.0.0
- sacremoses==0.0.53
- sentence-transformers==2.2.0
- datasets==2.1.0
- tokenizers==0.10.3
- transformers==4.10.3
# Next: resolve conflict between sentence-transfomers and pytorch-lightning
# pip uninstall sentencepiece

View file

@ -0,0 +1,97 @@
import argparse
import pathlib
import json
import glob
from load_aokvqa import load_aokvqa
def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):
if isinstance(dataset, list):
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
if multiple_choice is False:
dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
if strict:
dataset_qids = set(dataset.keys())
preds_qids = set(preds.keys())
assert dataset_qids.issubset(preds_qids)
# dataset = q_id (str) : dataset element (dict)
# preds = q_id (str) : prediction (str)
acc = []
for q in dataset.keys():
if q not in preds.keys():
acc.append(0.0)
continue
pred = preds[q]
choices = dataset[q]['choices']
direct_answers = dataset[q]['direct_answers']
## Multiple Choice setting
if multiple_choice:
if strict:
assert pred in choices, 'Prediction must be a valid choice'
correct_choice_idx = dataset[q]['correct_choice_idx']
acc.append( float(pred == choices[correct_choice_idx]) )
## Direct Answer setting
else:
num_match = sum([pred.lower() == da.lower() for da in direct_answers])
vqa_acc = min(1.0, num_match / 3.0)
acc.append(vqa_acc)
acc = sum(acc) / len(acc) * 100
return acc
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
args = parser.parse_args()
dataset = load_aokvqa(args.aokvqa_dir, args.split)
for prediction_file in glob.glob(args.prediction_files):
predictions = json.load(open(prediction_file, 'r'))
# Multiple choice
mc_predictions = {}
for q in predictions.keys():
if 'multiple_choice' in predictions[q].keys():
mc_predictions[q] = predictions[q]['multiple_choice']
if mc_predictions != {}:
mc_acc = eval_aokvqa(
dataset,
mc_predictions,
multiple_choice=True,
strict=False
)
print(prediction_file, 'MC', mc_acc)
# Direct Answer
da_predictions = {}
for q in predictions.keys():
if 'direct_answer' in predictions[q].keys():
da_predictions[q] = predictions[q]['direct_answer']
if da_predictions != {}:
da_acc = eval_aokvqa(
dataset,
da_predictions,
multiple_choice=False,
strict=False
)
print(prediction_file, 'DA', da_acc)

View file

@ -0,0 +1,13 @@
import os
import json
def load_aokvqa(aokvqa_dir, split, version='v1p0'):
assert split in ['train', 'val', 'test', 'test_w_ans']
dataset = json.load(open(
os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
))
return dataset
def get_coco_path(split, image_id, coco_dir):
return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")

View file

@ -0,0 +1,31 @@
import argparse
import pathlib
import json
from load_aokvqa import load_aokvqa
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file')
parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file')
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
args = parser.parse_args()
assert args.mc_pred_file or args.da_pred_file
dataset = load_aokvqa(args.aokvqa_dir, args.split)
mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None
da_preds = json.load(args.da_pred_file) if args.da_pred_file else None
predictions = {}
for d in dataset:
q = d['question_id']
predictions[q] = {}
if mc_preds and q in mc_preds.keys():
predictions[q]['multiple_choice'] = mc_preds[q]
if da_preds and q in da_preds.keys():
predictions[q]['direct_answer'] = da_preds[q]
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,44 @@
import argparse
import pathlib
import json
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from load_aokvqa import load_aokvqa
def map_to_choices(dataset, predictions, device='cpu'):
if isinstance(dataset, list):
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
return predictions
model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
model.to(device)
for q in tqdm(predictions.keys()):
choices = dataset[q]['choices']
if predictions[q] not in choices:
choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
predictions[q] = choices[a_idx]
return predictions
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
dataset = load_aokvqa(args.aokvqa_dir, args.split)
predictions = json.load(args.prediction_file)
predictions = map_to_choices(dataset, predictions)
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,14 @@
## Querying GPT-3
To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables.
```bash
export OPENAI_ORG=....
export OPENAI_API_KEY=...
```
For producing predictions for both DA and MC settings, run:
```bash
python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json
python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json
```

View file

@ -0,0 +1,23 @@
import os
import json
import argparse
import pathlib
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
parser.add_argument('--split', type=str, choices=['train', 'val'], required=True)
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations']
coco_captions = {c['image_id'] : c['caption'] for c in coco_captions}
captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set }
json.dump(captions, args.output_file)

View file

@ -0,0 +1,79 @@
import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib
import openai
openai.organization = os.getenv('OPENAI_ORG')
openai.api_key = os.getenv('OPENAI_API_KEY')
from load_aokvqa import load_aokvqa
random.seed(0)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--n', type=int, default=10, dest='num_examples')
parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
parser.add_argument('--include-choices', action='store_true', dest='include_choices')
parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
train_set = load_aokvqa(args.aokvqa_dir, 'train')
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
train_context = {}
context = {}
if args.context_file is not None:
train_context = json.load(args.train_context_file)
context = json.load(args.context_file)
predictions = {}
for d in tqdm(eval_set):
q = d['question_id']
prompt = args.prompt_prefix
for e in random.sample(train_set, args.num_examples):
prompt += prompt_element(e,
context=train_context.get(q, None),
include_choices=args.include_choices,
answer=True
)
prompt += '\n\n'
prompt += prompt_element(d,
context=context.get(q, None),
include_choices=args.include_choices,
answer=False
)
response = openai.Completion.create(
engine="text-curie-001",
prompt=prompt,
temperature=0.0,
max_tokens=10,
)
predictions[q] = response.choices[0].text.strip()
json.dump(predictions, args.output_file)
def prompt_element(d, context=None, include_choices=False, answer=False):
return (f"Context: {context}\n" if context is not None else '') + \
f"Q: {d['question']}\n" + \
(f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')
if __name__ == '__main__':
main()

View file

@ -0,0 +1,16 @@
import json
import argparse
import pathlib
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set}
json.dump(rationales, args.output_file)

View file

@ -0,0 +1,11 @@
## Heuristics
```bash
# These scripts accept the same arguments.
# heuristics/random_unweighted.py
# heuristics/random_weighted.py
# heuristics/most_common_answer.py
python heuristics/random_unweighted.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc --out ${PREDS_DIR}/random-unweighted_val-mc.json
# Exclude --mc for the direct answer setting
```

View file

@ -0,0 +1,39 @@
import os
import json
import argparse
import pathlib
from collections import Counter
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
train_set = load_aokvqa(args.aokvqa_dir, 'train')
train_freq = dict(Counter(
[d['choices'][d['correct_choice_idx']] for d in train_set]
))
most_common_answer = max(train_freq.keys(), key=train_freq.get)
##
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
predictions = {}
for d in eval_set:
q = d['question_id']
predictions[q] = most_common_answer
if args.multiple_choice:
choices = [c for c in d['choices'] if c in train_freq.keys()]
if len(choices) > 0:
predictions[q] = max(choices, key=train_freq.get)
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,38 @@
import os
import json
from random import seed, sample
import argparse
import pathlib
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
seed(0)
train_set = load_aokvqa(args.aokvqa_dir, 'train')
if args.multiple_choice is False:
choices = list(set(
[d['choices'][d['correct_choice_idx']] for d in train_set]
))
##
predictions = {}
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
for d in eval_set:
q = d['question_id']
if args.multiple_choice:
choices = d['choices']
predictions[q] = sample(choices, 1)[0]
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,46 @@
import os
import json
import numpy as np
import argparse
import pathlib
from collections import Counter
from load_aokvqa import load_aokvqa
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
args = parser.parse_args()
np.random.seed(0)
train_set = load_aokvqa(args.aokvqa_dir, 'train')
train_freq = dict(Counter(
[d['choices'][d['correct_choice_idx']] for d in train_set]
))
if args.multiple_choice is False:
choices = list(train_freq.keys())
probs = [f / len(train_set) for f in train_freq.values()]
##
predictions = {}
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
for d in eval_set:
if args.multiple_choice:
choices = d['choices']
probs = [train_freq.get(c, 0) for c in choices]
if probs == [0, 0, 0, 0]:
probs = [1, 1, 1, 1]
probs = [p / sum(probs) for p in probs]
q = d['question_id']
predictions[q] = np.random.choice(choices, size=1, p=probs)[0]
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,13 @@
import os
import json
def load_aokvqa(aokvqa_dir, split, version='v1p0'):
assert split in ['train', 'val', 'test', 'test_w_ans']
dataset = json.load(open(
os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
))
return dataset
def get_coco_path(split, image_id, coco_dir):
return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")

View file

@ -0,0 +1,41 @@
## Transfer Learning Experiments
We use the following training/prediction scripts for the classifier, zero-shot, and contrastive experiments in Table 3.
```bash
## Training
python transfer_experiments/train.py --aokvqa-dir ${AOKVQA_DIR} --vocab ${AOKVQA_DIR}/large_vocab_train.csv --log-dir ${LOG_DIR}
--backbone clip --clip-model-type ViT-B/32 --train-features ${FEATURES_DIR}/clip-ViT-B-32_train.pt --val-features ${FEATURES_DIR}/clip-ViT-B-32_val.pt
--inputs question # OR --inputs image # OR --inputs question image
# OR
--backbone resnet --train-features ${FEATURES_DIR}/resnet_train.pt --val-features ${FEATURES_DIR}/resnet_val.pt --inputs image
# OR
--backbone bert --train-features ${FEATURES_DIR}/bert_train.pt --val-features ${FEATURES_DIR}/bert_val.pt --inputs question
--objective classifier
# OR
--objective contrastive --vocab-features ${FEATURE_DIR}/clip-ViT-B-32_large_vocab.pt
```
You can make predictions for CLIP zero-shot or from a classifier/contrastive checkpoint trained above.
```bash
## Predicting
python transfer_experiments/predict.py --aokvqa-dir ${AOKVQA_DIR} --out ${PREDS_DIR}/clip-classifier_val-mc.json
--split val # or test
--features ${FEATURE_DIR}/clip-ViT-B-32_val.pt # adjust for backbone and eval split
--ckpt path/to/model.ckpt
# OR
--zero-shot --clip-model-type ViT-B/32
--inputs question # OR --inputs image # OR --inputs question image
--mc # Multiple-choice. Exclude for direct-answer.
# IF classifier OR direct-answer
--vocab ${AOKVQA_DIR}/large_vocab_train.csv
# IF contrastive/zero-shot AND direct-answer
--vocab-features ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
```

View file

@ -0,0 +1,126 @@
import sys
import os
import argparse
import pathlib
from tqdm import tqdm
import json
import torch
import torch.nn as nn
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl; import clip
from transfer_experiments.train import LinearClassifier
from load_aokvqa import load_aokvqa
from evaluation.remap_predictions import map_to_choices
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--features', type=pathlib.Path, required=True)
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
#
parser_weights = parser.add_mutually_exclusive_group(required=True)
parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')
parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv))
#
parser.add_argument('--vocab', type=argparse.FileType('r'))
parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--clip-model-type', type=str,
choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv))
#
args = parser.parse_args()
## Load dataset
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
## Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.checkpoint_path is not None:
classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
classifier.to(device)
hp = classifier.hparams
elif args.clip_zero_shot:
classifier = nn.Identity().to(device)
hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)
# Load input features
embeddings = torch.load(args.features)
if hp.backbone == 'clip':
for q in embeddings.keys():
embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True)
embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True)
# Load vocab, vocab features, clip
if (hp.objective == 'classifier') or \
(hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False):
vocab = args.vocab.read().splitlines()
if hp.objective in ['contrastive', 'zero-shot']:
if args.multiple_choice is False:
vocab_features = torch.load(args.vocab_features).cpu()
vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
else:
clip_model = clip.load(hp.clip_model_type, device=device)[0]
logit_scale = clip_model.logit_scale.exp().cpu()
## Prediction loop
predictions = {}
with torch.no_grad():
for o in tqdm(aokvqa_set):
q = o['question_id']
# Load input embedding (from question / image)
if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs):
e = embeddings[q]['question'] + embeddings[q]['image']
elif 'question' in hp.inputs and 'image' in hp.inputs:
e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
elif 'question' in hp.inputs:
e = embeddings[q]['question']
elif 'image' in hp.inputs:
e = embeddings[q]['image']
# Pass inputs through model
e = e.unsqueeze(0).to(device)
x = classifier(e)[0].cpu()
# Predict
if hp.objective in ['contrastive', 'zero-shot']:
if args.multiple_choice:
vocab = o['choices']
# Encode choices
vocab_features = clip.tokenize(vocab).to(device)
vocab_features = torch.stack([
clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features
], dim=1)[0]
vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
vocab_features = vocab_features.float().cpu()
x = logit_scale * x @ vocab_features.t()
x = x.softmax(dim=-1)
predictions[q] = vocab[x.argmax().item()]
## Save and evaluate predictions
# Map prediction to nearest neighbor choice (by word embeddings)
if args.multiple_choice and hp.objective == 'classifier':
predictions = map_to_choices(aokvqa_set, predictions)
json.dump(predictions, args.output_file)

View file

@ -0,0 +1,263 @@
import os
import sys
import json
import argparse
import pathlib
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl
import torchmetrics.functional as MF
from load_aokvqa import load_aokvqa
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
#
parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
parser.add_argument('--clip-model-type', type=str,
choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
dest='clip_model_type', required=('clip' in sys.argv))
parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
#
parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
# Defaults
parser.add_argument('--bs', type=int, default=128, dest='batch_size')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--gpus', type=int, default=1)
args = parser.parse_args()
pl.seed_everything(1)
vocab = args.vocab.read().splitlines()
## Data loading
dm = AokvqaEmbeddingsDataModule(
args.aokvqa_dir,
args.train_features,
args.val_features,
args.objective,
args.backbone,
args.inputs,
vocab,
args.vocab_features,
batch_size=args.batch_size,
num_workers=16
)
## Model definition
model = LinearClassifier(
args.objective,
args.backbone,
args.clip_model_type,
args.inputs,
len(vocab),
args.lr
)
## Training and testing loops
logger = pl.loggers.TensorBoardLogger(
args.log_dir,
name=f'{args.backbone}-{args.objective}',
version=f"inputs:{'+'.join(args.inputs)}"
)
trainer = pl.Trainer(
logger=logger,
gpus=args.gpus,
max_epochs=args.epochs,
callbacks=[
pl.callbacks.ModelCheckpoint(
monitor="val_acc",
filename="{epoch:02d}-{val_acc:.2f}",
mode="max"
)
],
)
trainer.fit(model, dm)
class AokvqaEmbeddingsDataset(Dataset):
def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):
aokvqa_set = load_aokvqa(aokvqa_dir, split)
assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
or ( backbone == 'clip' )
embeddings = torch.load(input_features)
if backbone == 'clip':
for q in embeddings.keys():
embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
if objective == 'contrastive':
vocab_embeddings = torch.load(vocab_features)
vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)
self.objective = objective
self.vocab_len = len(vocab)
self.embeddings = []
self.answers = []
for o in aokvqa_set:
correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
if self.objective == 'contrastive':
correct_answers = [vocab_embeddings[a] for a in correct_answers]
if len(correct_answers) == 0: continue
self.answers.append(correct_answers)
q = o['question_id']
if 'question' in inputs and 'image' in inputs:
e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
elif 'question' in inputs and 'image' not in inputs:
e = embeddings[q]['question']
elif 'question' not in inputs and 'image' in inputs:
e = embeddings[q]['image']
self.embeddings.append(e)
def __getitem__(self, index):
e = self.embeddings[index]
a = self.answers[index]
if self.objective == 'classifier':
a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
elif self.objective == 'contrastive':
a = random.sample(a, 1)[0]
return e, a
def __len__(self):
return len(self.embeddings)
class AokvqaEmbeddingsDataModule(pl.LightningDataModule):
def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
super().__init__()
self.aokvqa_dir = aokvqa_dir
self.train_features = train_features
self.val_features = val_features
self.objective = objective
self.backbone = backbone
self.inputs = inputs
self.vocab = vocab
self.vocab_features = vocab_features
self.batch_size = batch_size
self.num_workers = num_workers
def setup(self, stage=None):
self.train_dataset = AokvqaEmbeddingsDataset(
self.aokvqa_dir, 'train', self.train_features, self.objective,
self.backbone, self.inputs, self.vocab, self.vocab_features
)
self.val_dataset = AokvqaEmbeddingsDataset(
self.aokvqa_dir, 'val', self.val_features, self.objective,
self.backbone, self.inputs, self.vocab, self.vocab_features
)
def train_dataloader(self):
return DataLoader(
self.train_dataset, batch_size=self.batch_size, shuffle=True,
num_workers=int(0.8 * self.num_workers)
)
def val_dataloader(self):
return DataLoader(
self.val_dataset, batch_size=self.batch_size, shuffle=False,
num_workers=int(0.2 * self.num_workers)
)
class LinearClassifier(pl.LightningModule):
def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
super().__init__()
self.save_hyperparameters(ignore=['lr'])
self.lr = lr
if self.hparams.backbone == 'clip':
clip_dim = {
'RN50' : 1024,
'RN50x4' : 640,
'RN50x16' : 768,
'RN50x64' : 1024,
'RN101' : 512,
'ViT-B/32' : 512,
'ViT-B/16' : 512,
'ViT-L/14' : 768,
'ViT-L/14@336px' : 768,
}[clip_model_type]
emb_dim = clip_dim * len(inputs)
elif self.hparams.backbone == 'resnet':
emb_dim = 2048
elif self.hparams.backbone == 'bert':
emb_dim = 768
if self.hparams.objective == 'classifier':
out_dim = vocab_len
elif self.hparams.objective == 'contrastive':
out_dim = clip_dim
self.linear = nn.Linear(emb_dim, out_dim)
def forward(self, x):
x = self.linear(x)
if self.hparams.objective == 'classifier':
x = torch.sigmoid(x)
return x
def compute_loss(self, batch):
x, y = batch
y_pred = self.forward(x)
if self.hparams.objective == 'classifier':
loss = F.binary_cross_entropy(y_pred, y.float())
elif self.hparams.objective == 'contrastive':
indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
sim = (y_pred @ y.T).softmax(dim=-1)
loss = F.cross_entropy(sim, indices)
if self.hparams.objective == 'classifier':
acc = MF.f1_score(y_pred, y)
elif self.hparams.objective == 'contrastive':
acc = torch.mean(sim[indices, indices])
return loss, acc
def training_step(self, batch, batch_idx):
loss, acc = self.compute_loss(batch)
self.log("train_loss", loss)
self.log("train_acc", acc)
return loss
def validation_step(self, batch, batch_idx):
loss, acc = self.compute_loss(batch)
self.log("val_loss", loss)
self.log("val_acc", acc)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
if __name__ == '__main__':
main()

View file

@ -0,0 +1,211 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = "aagrawal"
__version__ = "0.9"
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(COCO.function)"
import json
import datetime
import copy
class VQA:
def __init__(self, annotation_file=None, question_file=None):
"""
Constructor of VQA helper class for reading and visualizing questions and answers.
:param annotation_file (str): location of VQA annotation file
:return:
"""
# load dataset
self.dataset = {}
self.questions = {}
self.qa = {}
self.qqa = {}
self.imgToQA = {}
if not annotation_file == None and not question_file == None:
print("loading VQA annotations and questions into memory...")
time_t = datetime.datetime.utcnow()
dataset = json.load(open(annotation_file, "r"))
questions = json.load(open(question_file, "r"))
self.dataset = dataset
self.questions = questions
self.createIndex()
def createIndex(self):
# create index
print("creating index...")
imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
for ann in self.dataset["annotations"]:
imgToQA[ann["image_id"]] += [ann]
qa[ann["question_id"]] = ann
for ques in self.questions["questions"]:
qqa[ques["question_id"]] = ques
print("index created!")
# create class members
self.qa = qa
self.qqa = qqa
self.imgToQA = imgToQA
def info(self):
"""
Print information about the VQA annotation file.
:return:
"""
for key, value in self.datset["info"].items():
print("%s: %s" % (key, value))
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
"""
Get question ids that satisfy given filter conditions. default skips that filter
:param imgIds (int array) : get question ids for given imgs
quesTypes (str array) : get question ids for given question types
ansTypes (str array) : get question ids for given answer types
:return: ids (int array) : integer array of question ids
"""
imgIds = imgIds if type(imgIds) == list else [imgIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset["annotations"]
else:
if not len(imgIds) == 0:
anns = sum(
[self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
[],
)
else:
anns = self.dataset["annotations"]
anns = (
anns
if len(quesTypes) == 0
else [ann for ann in anns if ann["question_type"] in quesTypes]
)
anns = (
anns
if len(ansTypes) == 0
else [ann for ann in anns if ann["answer_type"] in ansTypes]
)
ids = [ann["question_id"] for ann in anns]
return ids
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
"""
Get image ids that satisfy given filter conditions. default skips that filter
:param quesIds (int array) : get image ids for given question ids
quesTypes (str array) : get image ids for given question types
ansTypes (str array) : get image ids for given answer types
:return: ids (int array) : integer array of image ids
"""
quesIds = quesIds if type(quesIds) == list else [quesIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset["annotations"]
else:
if not len(quesIds) == 0:
anns = sum(
[self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
)
else:
anns = self.dataset["annotations"]
anns = (
anns
if len(quesTypes) == 0
else [ann for ann in anns if ann["question_type"] in quesTypes]
)
anns = (
anns
if len(ansTypes) == 0
else [ann for ann in anns if ann["answer_type"] in ansTypes]
)
ids = [ann["image_id"] for ann in anns]
return ids
def loadQA(self, ids=[]):
"""
Load questions and answers with the specified question ids.
:param ids (int array) : integer ids specifying question ids
:return: qa (object array) : loaded qa objects
"""
if type(ids) == list:
return [self.qa[id] for id in ids]
elif type(ids) == int:
return [self.qa[ids]]
def showQA(self, anns):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
for ann in anns:
quesId = ann["question_id"]
print("Question: %s" % (self.qqa[quesId]["question"]))
for ans in ann["answers"]:
print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
def loadRes(self, resFile, quesFile):
"""
Load result file and return a result object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
res = VQA()
res.questions = json.load(open(quesFile))
res.dataset["info"] = copy.deepcopy(self.questions["info"])
res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
res.dataset["license"] = copy.deepcopy(self.questions["license"])
print("Loading and preparing results... ")
time_t = datetime.datetime.utcnow()
anns = json.load(open(resFile))
assert type(anns) == list, "results is not an array of objects"
annsQuesIds = [ann["question_id"] for ann in anns]
assert set(annsQuesIds) == set(
self.getQuesIds()
), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
for ann in anns:
quesId = ann["question_id"]
if res.dataset["task_type"] == "Multiple Choice":
assert (
ann["answer"] in self.qqa[quesId]["multiple_choices"]
), "predicted answer is not one of the multiple choices"
qaAnn = self.qa[quesId]
ann["image_id"] = qaAnn["image_id"]
ann["question_type"] = qaAnn["question_type"]
ann["answer_type"] = qaAnn["answer_type"]
print(
"DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
)
res.dataset["annotations"] = anns
res.createIndex()
return res

View file

@ -0,0 +1,324 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# coding=utf-8
__author__ = "aagrawal"
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re
class VQAEval:
def __init__(self, vqa=None, vqaRes=None, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
if vqa is not None:
self.params = {"question_id": vqa.getQuesIds()}
self.contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
self.manualMap = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
self.articles = ["a", "an", "the"]
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(,)(\d)")
self.punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def evaluate(self, quesIds=None):
if quesIds == None:
quesIds = [quesId for quesId in self.params["question_id"]]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
print("computing accuracy")
step = 0
for quesId in quesIds:
resAns = res[quesId]["answer"]
resAns = resAns.replace("\n", " ")
resAns = resAns.replace("\t", " ")
resAns = resAns.strip()
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
gtAcc = []
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]["answers"]:
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
for gtAnsDatum in gts[quesId]["answers"]:
otherGTAns = [
item for item in gts[quesId]["answers"] if item != gtAnsDatum
]
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
quesType = gts[quesId]["question_type"]
ansType = gts[quesId]["answer_type"]
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step % 100 == 0:
self.updateProgress(step / float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
print("Done computing accuracy")
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + " " in inText or " " + p in inText) or (
re.search(self.commaStrip, inText) != None
):
outText = outText.replace(p, "")
else:
outText = outText.replace(p, " ")
outText = self.periodStrip.sub("", outText, re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = " ".join(outText)
return outText
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
self.accuracy["perQuestionType"] = {
quesType: round(
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
self.n,
)
for quesType in accQuesType
}
self.accuracy["perAnswerType"] = {
ansType: round(
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
)
for ansType in accAnsType
}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100 * acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength * progress))
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
"#" * block + "-" * (barLength - block), int(progress * 100), status
)
sys.stdout.write(text)
sys.stdout.flush()

654
models/criteria.py Normal file
View file

@ -0,0 +1,654 @@
from functools import lru_cache
import torch
import torch.nn.functional as F
from torch import nn
from models.utils import allgather_wgrad
from utils.dist import get_rank, get_world_size
from utils.easydict import EasyDict
def get_sim(
x_proj: torch.Tensor,
y_proj: torch.Tensor,
temp=1.0,
):
"""calculate pair-wise similarity between two modalities x and y.
Args:
x_proj (torch.Tensor): The representation of modality x. Shape: [B,T,C] or [B,C].
y_proj (torch.Tensor): The representation of modality y. Shape: [B,C].
temp (torch.Tensor): The temperature. Shape: [].
Returns: The similarity between modality x and y. Shape: [B,B].
"""
x_proj = F.normalize(x_proj, dim=-1)
y_proj = F.normalize(y_proj, dim=-1)
assert x_proj.dim() in [2, 3]
assert y_proj.dim() == 2
if x_proj.dim() == 2:
sim_x2y = torch.einsum("md,nd->mn", x_proj, y_proj) / temp # (B,B)
else:
sim_x2y = torch.einsum("mld,nd->mln", x_proj, y_proj).mean(1) / temp # (B,B)
sim_y2x = sim_x2y.T
return sim_x2y, sim_y2x
class ContMatchLoss(nn.Module):
def __init__(self):
super(ContMatchLoss, self).__init__()
@torch.no_grad()
def get_mask(self, sim, idx=None, normalize=False):
"""
Args:
sim (torch.Tensor): The similarity between videos and texts. shape: (B, B).
idx (torch.Tensor): The index for each video. Shape: [B].
normalize (bool): If true, make row sum equal to 1
"""
if idx is not None:
idx = idx.view(-1, 1)
mask = torch.eq(idx, idx.T).to(sim.dtype)
if normalize:
mask = mask / mask.sum(1, keepdim=True)
else:
mask = torch.zeros_like(sim)
mask.fill_diagonal_(1)
return mask # `1` mark valid/matched location
@lru_cache(maxsize=16)
def get_gather_args(self):
"""obtain the args for all_gather
Returns: dict.
"""
return EasyDict({"world_size": get_world_size(), "rank": get_rank()})
class STC_STM_Loss(ContMatchLoss):
"""Contrastive and matching losses"""
def __init__(self):
super(STC_STM_Loss, self).__init__()
def stc_loss(
self,
temporal_proj: torch.Tensor,
spatial_proj: torch.Tensor,
idx: torch.Tensor,
temp=1.0,
all_gather=True
):
"""forward to calculate the loss
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
idx (torch.Tensor): The index for each example. Shape: [B,].
temp (torch.Tensor): The temperature. Shape: [].
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
"""
if all_gather:
gather_args = self.get_gather_args()
temporal_proj = allgather_wgrad(temporal_proj, gather_args)
spatial_proj = allgather_wgrad(spatial_proj, gather_args)
if idx is not None:
idx = allgather_wgrad(idx, gather_args)
sim_t2s, sim_s2t = get_sim(temporal_proj, spatial_proj, temp)
with torch.no_grad():
sim_t2s_targets = self.get_mask(sim_t2s, idx=idx, normalize=True)
sim_s2t_targets = sim_t2s_targets
loss_t2s = -torch.sum(F.log_softmax(sim_t2s, dim=1) * sim_t2s_targets, dim=1).mean()
loss_s2t = -torch.sum(F.log_softmax(sim_s2t, dim=1) * sim_s2t_targets, dim=1).mean()
loss_stc = (loss_t2s + loss_s2t) / 2
return loss_stc
def stm_loss(
self,
grounding_expert,
stm_head,
# temp,
spatial_embeds_orig,
temporal_embeds_orig,
temporal_proj,
spatial_proj,
idx,
generation=False,
temp=1.0
):
spatial_embeds = spatial_embeds_orig.clone()
temporal_embeds = temporal_embeds_orig.clone()
with torch.no_grad():
sim_s2t, sim_t2s = get_sim(temporal_proj, spatial_proj, temp)
spatial_atts = torch.ones(
spatial_embeds.size()[:-1], dtype=torch.long, device=spatial_embeds.device
)
temporal_atts = torch.ones(
temporal_embeds.size()[:-1], dtype=torch.long, device=temporal_embeds.device
)
weights_s2t = F.softmax(sim_s2t + 1e-4, dim=1) # (N, N)
weights_t2s = F.softmax(sim_t2s + 1e-4, dim=1)
mask = self.get_mask(sim_s2t, idx=idx).bool()
weights_s2t.masked_fill_(mask, 0)
weights_t2s.masked_fill_(mask, 0)
weights_s2t = torch.nan_to_num_(weights_s2t, nan=1e-2, posinf=1e-2, neginf=1e-2)
weights_t2s = torch.nan_to_num_(weights_t2s, nan=1e-2, posinf=1e-2, neginf=1e-2)
if generation:
with torch.no_grad():
output = grounding_expert(
encoder_embeds=temporal_embeds,
attention_mask=temporal_atts,
encoder_hidden_states=spatial_embeds,
encoder_attention_mask=spatial_atts,
return_dict=True,
)
pos_feats = output.last_hidden_state
return pos_feats
else:
# select a hard negatives within the batch
spatial_neg_indices = torch.multinomial(weights_s2t, 1).squeeze()
temporal_neg_indices = torch.multinomial(weights_t2s, 1).squeeze()
spatial_embeds_neg = spatial_embeds[spatial_neg_indices] # [B, L, c]
temporal_embeds_neg = temporal_embeds[temporal_neg_indices] # [B, L, d]
# temporal_atts_neg = temporal_atts[temporal_neg_indices]
# concat embeddings
spatial_embeds_all = torch.cat([spatial_embeds, spatial_embeds_neg, spatial_embeds], dim=0)
temporal_embeds_all = torch.cat([temporal_embeds, temporal_embeds, temporal_embeds_neg], dim=0)
spatial_atts_all = torch.cat([spatial_atts, spatial_atts, spatial_atts], dim=0)
temporal_atts_all = torch.cat([temporal_atts, temporal_atts, temporal_atts], dim=0)
output = grounding_expert(
inputs_embeds=temporal_embeds_all,
attention_mask=temporal_atts_all,
cross_embeds=spatial_embeds_all,
cross_attention_mask=spatial_atts_all,
)
stm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
stm_logits = stm_head(stm_embeds) # [3*B, 2]
bs = stm_logits.shape[0] // 3
stm_labels = stm_logits.new_ones(3 * bs, dtype=torch.long)
stm_labels[bs:] = 0
loss_stm = F.cross_entropy(stm_logits, stm_labels)
pos_feats = output.last_hidden_state[:bs]
return loss_stm, pos_feats
class VCC_VCM_Loss(ContMatchLoss):
"""Contrastive and matching losses"""
def __init__(self):
super(VCC_VCM_Loss, self).__init__()
def vcc_loss(
self,
vis_proj: torch.Tensor,
cap_proj: torch.Tensor,
idx: torch.Tensor,
temp=1.0,
all_gather=True
):
"""forward to calculate the loss
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
idx (torch.Tensor): The index for each example. Shape: [B,].
temp (torch.Tensor): The temperature. Shape: [].
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
"""
if all_gather:
gather_args = self.get_gather_args()
vis_proj = allgather_wgrad(vis_proj, gather_args)
cap_proj = allgather_wgrad(cap_proj, gather_args)
if idx is not None:
idx = allgather_wgrad(idx, gather_args)
sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)
with torch.no_grad():
sim_v2c_targets = self.get_mask(sim_v2c, idx=idx, normalize=True)
sim_c2v_targets = sim_v2c_targets
loss_v2c = -torch.sum(F.log_softmax(sim_v2c, dim=1) * sim_v2c_targets, dim=1).mean()
loss_c2v = -torch.sum(F.log_softmax(sim_c2v, dim=1) * sim_c2v_targets, dim=1).mean()
loss_vcc = (loss_v2c + loss_c2v) / 2
return loss_vcc
def vcm_loss(
self,
grounding_expert,
vcm_head,
vis_embeds_orig,
cap_embeds_orig,
vis_proj,
cap_proj,
cap_atts,
idx,
generation=False,
temp=1.0
):
vis_embeds = vis_embeds_orig.clone()
cap_embeds = cap_embeds_orig.clone()
with torch.no_grad():
sim_v2c, sim_c2v = get_sim(vis_proj, cap_proj, temp)
vis_atts = torch.ones(
vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
)
weights_v2c = F.softmax(sim_v2c + 1e-4, dim=1) # (N, N)
weights_c2v = F.softmax(sim_c2v + 1e-4, dim=1)
mask = self.get_mask(weights_v2c, idx=idx).bool()
weights_v2c.masked_fill_(mask, 0)
weights_c2v.masked_fill_(mask, 0)
weights_v2c = torch.nan_to_num_(weights_v2c, nan=1e-2, posinf=1e-2, neginf=1e-2)
weights_c2v = torch.nan_to_num_(weights_c2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
if generation:
with torch.no_grad():
output = grounding_expert(
encoder_embeds=cap_embeds,
attention_mask=cap_atts,
encoder_hidden_states=vis_embeds,
encoder_attention_mask=vis_atts,
return_dict=True,
)
pos_feats = output.last_hidden_state
return pos_feats
else:
# select a hard negatives within the batch
vis_neg_indices = torch.multinomial(weights_v2c, 1).squeeze()
cap_neg_indices = torch.multinomial(weights_c2v, 1).squeeze()
vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c]
cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, d]
cap_atts_neg = cap_atts[cap_neg_indices]
# concat embeddings
vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
cap_embeds_all = torch.cat([cap_embeds, cap_embeds, cap_embeds_neg], dim=0)
vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
cap_atts_all = torch.cat([cap_atts, cap_atts, cap_atts_neg], dim=0)
output = grounding_expert(
inputs_embeds=cap_embeds_all,
attention_mask=cap_atts_all,
cross_embeds=vis_embeds_all,
cross_attention_mask=vis_atts_all,
)
vcm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
vcm_logits = vcm_head(vcm_embeds) # [3*B, 2]
bs = vcm_logits.shape[0] // 3
vcm_labels = vcm_logits.new_ones(3 * bs, dtype=torch.long)
vcm_labels[bs:] = 0
loss_vcm = F.cross_entropy(vcm_logits, vcm_labels)
pos_feats = output.last_hidden_state[:bs]
return loss_vcm, pos_feats
class VHC_VHM_Loss(ContMatchLoss):
"""Contrastive and matching losses"""
def __init__(self):
super(VHC_VHM_Loss, self).__init__()
def vhc_loss(
self,
vis_proj: torch.Tensor,
hist_proj: torch.Tensor,
idx: torch.Tensor,
temp=1.0,
all_gather=True
):
"""forward to calculate the loss
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
idx (torch.Tensor): The index for each example. Shape: [B,].
temp (torch.Tensor): The temperature. Shape: [].
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
"""
if all_gather:
gather_args = self.get_gather_args()
vis_proj = allgather_wgrad(vis_proj, gather_args)
hist_proj = allgather_wgrad(hist_proj, gather_args)
if idx is not None:
idx = allgather_wgrad(idx, gather_args)
sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)
with torch.no_grad():
sim_v2h_targets = self.get_mask(sim_v2h, idx=idx, normalize=True)
sim_h2v_targets = sim_v2h_targets
loss_v2h = -torch.sum(F.log_softmax(sim_v2h, dim=1) * sim_v2h_targets, dim=1).mean()
loss_h2v = -torch.sum(F.log_softmax(sim_h2v, dim=1) * sim_h2v_targets, dim=1).mean()
loss_vhc = (loss_v2h + loss_h2v) / 2
return loss_vhc
def vhm_loss(
self,
grounding_expert,
vhm_head,
vis_embeds_orig,
hist_embeds_orig,
vis_proj,
hist_proj,
hist_atts,
idx,
generation=False,
temp=1.0,
):
vis_embeds = vis_embeds_orig.clone()
hist_embeds = hist_embeds_orig.clone()
with torch.no_grad():
sim_v2h, sim_h2v = get_sim(vis_proj, hist_proj, temp)
vis_atts = torch.ones(
vis_embeds.size()[:-1], dtype=torch.long, device=vis_embeds.device
)
weights_v2h = F.softmax(sim_v2h + 1e-4, dim=1) # (N, N)
weights_h2v = F.softmax(sim_h2v + 1e-4, dim=1)
mask = self.get_mask(weights_v2h, idx=idx).bool()
weights_v2h.masked_fill_(mask, 0)
weights_h2v.masked_fill_(mask, 0)
weights_v2h = torch.nan_to_num_(weights_v2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
weights_h2v = torch.nan_to_num_(weights_h2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
if generation:
with torch.no_grad():
output = grounding_expert(
encoder_embeds=hist_embeds,
attention_mask=hist_atts,
encoder_hidden_states=vis_embeds,
encoder_attention_mask=vis_atts,
return_dict=True,
# mode="fusion",
)
pos_feats = output.last_hidden_state
return pos_feats
else:
# select a hard negatives within the batch
vis_neg_indices = torch.multinomial(weights_v2h, 1).squeeze()
hist_neg_indices = torch.multinomial(weights_h2v, 1).squeeze()
vis_embeds_neg = vis_embeds[vis_neg_indices] # [B, L, c]
hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d]
hist_atts_neg = hist_atts[hist_neg_indices]
# concat embeddings
vis_embeds_all = torch.cat([vis_embeds, vis_embeds_neg, vis_embeds], dim=0)
hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
vis_atts_all = torch.cat([vis_atts, vis_atts, vis_atts], dim=0)
hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)
output = grounding_expert(
inputs_embeds=hist_embeds_all,
attention_mask=hist_atts_all,
cross_embeds=vis_embeds_all,
cross_attention_mask=vis_atts_all,
)
vhm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
vhm_logits = vhm_head(vhm_embeds) # [3*B, 2]
bs = vhm_logits.shape[0] // 3
vhm_labels = vhm_logits.new_ones(3 * bs, dtype=torch.long)
vhm_labels[bs:] = 0
loss_vhm = F.cross_entropy(vhm_logits, vhm_labels)
pos_feats = output.last_hidden_state[:bs]
return loss_vhm, pos_feats
class CHC_CHM_Loss(ContMatchLoss):
"""Contrastive and matching losses"""
def __init__(self):
super(CHC_CHM_Loss, self).__init__()
def chc_loss(
self,
cap_proj: torch.Tensor,
hist_proj: torch.Tensor,
idx: torch.Tensor,
temp=1.0,
all_gather=True
):
"""forward to calculate the loss
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
idx (torch.Tensor): The index for each example. Shape: [B,].
temp (torch.Tensor): The temperature. Shape: [].
all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
"""
if all_gather:
gather_args = self.get_gather_args()
cap_proj = allgather_wgrad(cap_proj, gather_args)
hist_proj = allgather_wgrad(hist_proj, gather_args)
if idx is not None:
idx = allgather_wgrad(idx, gather_args)
sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)
with torch.no_grad():
sim_c2h_targets = self.get_mask(sim_c2h, idx=idx, normalize=True)
sim_h2c_targets = sim_c2h_targets
loss_c2h = -torch.sum(F.log_softmax(sim_c2h, dim=1) * sim_c2h_targets, dim=1).mean()
loss_h2c = -torch.sum(F.log_softmax(sim_h2c, dim=1) * sim_h2c_targets, dim=1).mean()
loss_chc = (loss_c2h + loss_h2c) / 2
return loss_chc
def chm_loss(
self,
grounding_expert,
chm_head,
cap_embeds_orig,
hist_embeds_orig,
cap_proj,
hist_proj,
cap_atts,
hist_atts,
idx,
generation=False,
temp=1.0
):
cap_embeds = cap_embeds_orig.clone()
hist_embeds = hist_embeds_orig.clone()
with torch.no_grad():
sim_c2h, sim_h2c = get_sim(cap_proj, hist_proj, temp)
weights_c2h = F.softmax(sim_c2h + 1e-4, dim=1) # (N, N)
weights_h2c = F.softmax(sim_h2c + 1e-4, dim=1)
mask = self.get_mask(weights_c2h, idx=idx).bool()
weights_c2h.masked_fill_(mask, 0)
weights_h2c.masked_fill_(mask, 0)
weights_c2h = torch.nan_to_num_(weights_c2h, nan=1e-2, posinf=1e-2, neginf=1e-2)
weights_h2c = torch.nan_to_num_(weights_h2c, nan=1e-2, posinf=1e-2, neginf=1e-2)
if generation:
with torch.no_grad():
output = grounding_expert(
encoder_embeds=hist_embeds,
attention_mask=hist_atts,
encoder_hidden_states=cap_embeds,
encoder_attention_mask=cap_atts,
return_dict=True,
)
pos_feats = output.last_hidden_state
return pos_feats
else:
# select a hard negatives within the batch
cap_neg_indices = torch.multinomial(weights_c2h, 1).squeeze()
hist_neg_indices = torch.multinomial(weights_h2c, 1).squeeze()
cap_embeds_neg = cap_embeds[cap_neg_indices] # [B, L, c]
cap_atts_neg = cap_atts[cap_neg_indices]
hist_embeds_neg = hist_embeds[hist_neg_indices] # [B, L, d]
hist_atts_neg = hist_atts[hist_neg_indices]
# concat embeddings
cap_embeds_all = torch.cat([cap_embeds, cap_embeds_neg, cap_embeds], dim=0)
hist_embeds_all = torch.cat([hist_embeds, hist_embeds, hist_embeds_neg], dim=0)
cap_atts_all = torch.cat([cap_atts, cap_atts_neg, cap_atts], dim=0)
hist_atts_all = torch.cat([hist_atts, hist_atts, hist_atts_neg], dim=0)
output = grounding_expert(
inputs_embeds=hist_embeds_all,
attention_mask=hist_atts_all,
cross_embeds=cap_embeds_all,
cross_attention_mask=cap_atts_all,
)
chm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
chm_logits = chm_head(chm_embeds) # [3*B, 2]
bs = chm_logits.shape[0] // 3
chm_labels = chm_logits.new_ones(3 * bs, dtype=torch.long)
chm_labels[bs:] = 0
loss_chm = F.cross_entropy(chm_logits, chm_labels)
pos_feats = output.last_hidden_state[:bs]
return loss_chm, pos_feats
class MLMLoss(nn.Module):
"""masked language modeling loss."""
def __init__(self, masking_prob, tokenizer):
super(MLMLoss, self).__init__()
self.tokenizer = tokenizer
self.masking_prob = masking_prob
def mlm_loss(
self,
text_encoder,
text,
text_embeds,
vision_embeds,
vision_atts,
):
input_ids = text.input_ids.clone()
labels = input_ids.clone()
probability_matrix = torch.full(labels.shape, self.masking_prob)
input_ids, labels = self.mask(
input_ids,
text_encoder.config.vocab_size,
input_ids.device,
targets=labels,
probability_matrix=probability_matrix,
)
# intermediate_mlm_output = text_encoder.bert(
# input_ids,
# attention_mask=text.attention_mask,
# encoder_hidden_states=vision_embeds,
# encoder_attention_mask=vision_atts,
# return_dict=True,
# # mode="text",
# )
# text_embeds = intermediate_mlm_output.last_hidden_state
mlm_output = text_encoder(
encoder_embeds=text_embeds,
attention_mask=text.attention_mask,
encoder_hidden_states=vision_embeds,
encoder_attention_mask=vision_atts,
return_dict=True,
labels=labels,
soft_labels=None,
# mode="fusion",
)
return mlm_output.loss
def mask(
self,
input_ids,
vocab_size,
device,
targets=None,
masked_indices=None,
probability_matrix=None,
):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
# We only compute loss on masked tokens
targets[~masked_indices] = -100
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
)
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = (
torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids

View file

View file

@ -0,0 +1,286 @@
import logging
import math
import einops
import torch
from einops import rearrange
from timm.models.layers.drop import DropPath
from torch import nn
from torch.nn import LayerNorm, Linear, MultiheadAttention
logger = logging.getLogger(__name__)
class STAdapter(nn.Module):
"""ST Adapter"""
def __init__(
self,
kernel_size=(3, 3, 3),
input_dim=768,
hidden_dim=384,
img_size=224,
patch_size=16,
drop_prob=0.1,
):
super(STAdapter, self).__init__()
self.kernel_size = kernel_size
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.h = self.w = img_size // patch_size
self.linear1 = nn.Linear(input_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, input_dim)
self.act = nn.ReLU()
self.conv = nn.Conv3d(
hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim
)
self.droppath = DropPath(drop_prob=drop_prob)
self.scale = nn.parameter.Parameter(torch.zeros([]))
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
x = self.linear1(x)
cls = x[:, :, :1, :]
tokens = x[:, :, 1:, :]
tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous()
tokens = self.conv(tokens)
tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c")
x = torch.cat([cls, tokens], dim=2) # [b, t, 1+h*w, c]
x = self.act(x)
x = self.linear2(x)
return shortcut + self.scale * self.droppath(x)
class SpatialAttention(nn.Module):
"""Perfrom spatial self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
super(SpatialAttention, self).__init__()
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.linear = Linear(input_dim, input_dim)
self.droppath = DropPath(droppath_rate)
# self.scale = nn.parameter.Parameter(torch.zeros([]))
self.scale = 1.0
def forward(self, x: torch.Tensor):
if x.shape[1] == 1:
x = self.norm(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x # return self if media is image
shortcut = x
x = einops.rearrange(x, 'b t l c -> (b t) l c')
x = self.norm(x)
x = self.attn(x, x, x)[0]
x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0])
x = shortcut + self.scale * self.droppath(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
class TemporalAttention(nn.Module):
"""perform temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super(TemporalAttention, self).__init__()
self._input_dim = input_dim
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.linear = Linear(input_dim, input_dim)
self.droppath = DropPath(droppath_rate)
# self.scale = nn.parameter.Parameter(torch.zeros([]))
self.scale = 1.0
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
x = self.norm(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
shortcut = x
x = einops.rearrange(x, "b t l c -> (b l) t c")
x = self.norm(x)
x = self.attn(x, x, x)[0]
x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0])
x = shortcut + self.scale * self.droppath(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
class WindowTemporalAttention(nn.Module):
"""perform windowed temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
self._input_dim = input_dim
self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.droppath = DropPath(droppath_rate)
self.scale = nn.parameter.Parameter(torch.zeros([]))
self.wh, self.ww = window_size
# logger.info(f"WindowTemporalAttention: window_size: {window_size}")
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
h = w = int(math.sqrt(x.shape[2] - 1))
cls_token = x[:, :, :1, :]
x = einops.rearrange(
x[:, :, 1:, :],
"b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c",
nh=h // self.wh,
wh=self.wh,
nw=w // self.ww,
ww=self.ww,
)
x = self.norm(x)
x = self.temporal_attn(x, x, x)[0]
x = einops.rearrange(
x,
"(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c",
wh=self.wh,
ww=self.ww,
nh=h // self.wh,
nw=w // self.ww,
)
# add back cls token.
x = torch.concat([cls_token, x], dim=2)
return shortcut + self.scale * self.droppath(x)
class X_CLIP(nn.Module):
"""perform windowed temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
d_model = input_dim
self.message_fc = nn.Linear(d_model, d_model)
self.message_ln = LayerNorm(d_model, eps=1e-12)
self.message_attn = nn.MultiheadAttention(d_model, d_model // 64)
self.num_prompts = num_prompts
self.droppath = DropPath(droppath_rate)
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
msg_token = self.message_ln(self.message_fc(x[:, :, 0, :])) # [b, t, c]
msg_token = rearrange(msg_token, "b t c -> t b c")
msg_token = msg_token + self.droppath(
self.message_attn(msg_token, msg_token, msg_token)[0]
)
msg_token = rearrange(msg_token, "t b c -> b t c")
# replace the last prompt token with msg_token.
x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2) # [b, t, l+1, c]
return x
class TemporalS4(nn.Module):
"""perform temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
from .s4 import S4
self._input_dim = input_dim
self.norm = LayerNorm(input_dim, eps=1e-12)
self.droppath = DropPath(droppath_rate)
self.scale = nn.parameter.Parameter(torch.zeros([]))
self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True)
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
x = self.norm(x)
x = einops.rearrange(x, "b t l c -> b c (t l)")
x, _ = self.s4(x)
x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1])
return shortcut + self.scale * self.droppath(x)

358
models/setup.py Normal file
View file

@ -0,0 +1,358 @@
import copy
import os.path as osp
import glog as logger
import torch
from torch.utils.data import ConcatDataset
from models.backbones.beit.builder import interpolate_pos_embed_beit
from models.backbones.bert.tokenization_bert import BertTokenizer
from transformers import T5Tokenizer, BartTokenizer, LlamaTokenizer
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
from datasets.dataloader import load_dataloaders
from datasets.pretraining import load_datasets as load_datasets_stage_1
from datasets.visdial_dataset import load_visdial_dataset
from datasets.champagne_dataset import load_champagne_dataset
from datasets.nextqa_dataset import load_nextqa_dataset
from datasets.avsd_dataset import load_avsd_dataset
# from datasets.avsd_dataset_like_mixer import load_avsd_dataset
from processors.blip_processors import Blip2ImageTrainProcessor
from processors.blip_processors import BlipCaptionProcessor, BlipDialogProcessor
from utils.init import set_training_steps
# from models.v2dial import V2Dial, V2DialBase
from models.v2dial import V2DialBase, V2Dial, V2DialNoMoes
# from datasets.avsd_dataset import get_dataset, AVSDDataSet
from torch.utils.data import DataLoader
def setup_model(
config, has_decoder=False, pretrain=False, find_unused_parameters=True
):
logger.info("Creating model")
if config['stage'] == 'stage_1':
config = copy.deepcopy(config)
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = V2DialBase(config=config, expert_tokenizer=tokenizer)
model = V2DialBase(config)
model = model.to(torch.device('cuda'))
model_without_ddp = model
optimizer = create_optimizer(config, model)
scheduler = create_scheduler(config, optimizer)
scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
if config['distributed']:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[config['gpu']],
find_unused_parameters=find_unused_parameters, # `False` for image-only task
)
start_epoch = 0
global_step = 0
webvid_step = 0
cc3m_step = 0
if osp.isfile(config['pretrained_path']):
logger.info(f"Loading checkpoint from {config['pretrained_path']}")
checkpoint = torch.load(config['pretrained_path'], map_location="cpu")
state_dict = checkpoint["model"]
if config.resume:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
scaler.load_state_dict(checkpoint["scaler"])
start_epoch = checkpoint["epoch"] + 1
global_step = checkpoint["global_step"]
elif not pretrain: # downstream init from pretrained ckpt
# interpolate positional embeddings.
state_dict = interpolate_pos_embed_beit(state_dict, model_without_ddp)
#TODO Might need to update to match the MoEs
if not config.evaluate: # finetuning from a pretarined weights.
for key in list(state_dict.keys()):
if "bert" in key:
encoder_key = key.replace("bert.", "")
state_dict[encoder_key] = state_dict[key]
if not has_decoder:
del state_dict[key]
# init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
# only for generation tasks like VQA
if has_decoder and "text_encoder" in key:
if "layer" in key:
encoder_keys = key.split(".")
layer_num = int(encoder_keys[4])
if layer_num < config.model.text_encoder.fusion_layer:
del state_dict[key]
continue
else:
decoder_layer_num = layer_num - 9
encoder_keys[4] = str(decoder_layer_num)
encoder_key = ".".join(encoder_keys)
else:
encoder_key = key
decoder_key = encoder_key.replace("text_encoder", "text_decoder")
state_dict[decoder_key] = state_dict[key]
del state_dict[key]
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
logger.info(msg)
logger.info(f"Loaded checkpoint from {config.pretrained_path}")
else:
logger.warning("No pretrained checkpoint provided, training from scratch")
return (
model,
model_without_ddp,
optimizer,
scheduler,
scaler,
start_epoch,
global_step,
webvid_step,
cc3m_step,
config
)
else:
# config = copy.deepcopy(config)
# if config['use_original_feats']:
# model = AVSDBart(config)
# else:
# # model = V2Dial(config, tokenizer_experts, tokenizer_enc_dec)
# if config.use_moes:
model = V2Dial(config)
# else:
# model = V2DialNoMoes(config)
model = model.to(torch.device('cuda'))
model_without_ddp = model
optimizer = None
scheduler = None
scaler = None
start_epoch = 0
global_step = 0
if config['stage'] == 'stage_3':
visdial_step = 0
avsd_step = 0
nextqa_step = 0
ckpt_path = config.pretrained_path_resume if config.resume else config.pretrained_path_prev_stage
if config.generating:
ckpt_path = config.best_ckpt_path
if osp.isfile(ckpt_path):
logger.info(f"Loading checkpoint from {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["model"]
if config.resume:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
scaler.load_state_dict(checkpoint["scaler"])
start_epoch = checkpoint["epoch"] + 1
global_step = checkpoint["global_step"]
if config['stage'] == 'stage_3':
visdial_step = checkpoint['visdial_step']
avsd_step = checkpoint['avsd_step']
next_step = checkpoint['nextqa_step']
if config['stage'] in ['stage_2', 'stage_3'] and config.use_moes:
# Init. the history expert erights with the caption expert weights
p_names = [
'moe_layers.{}.norm_hist.weight',
'moe_layers.{}.mlp_hist.fc1.weight',
'moe_layers.{}.mlp_hist.fc1.bias',
'moe_layers.{}.mlp_hist.fc2.weight',
'moe_layers.{}.mlp_hist.fc2.bias',
]
for moe_layer_idx in range(config.num_moe_modality_layers):
for p_name in p_names:
p_hist_name = p_name.format(moe_layer_idx)
if p_hist_name not in state_dict:
p_cap_name = p_hist_name.replace('hist', 'cap')
state_dict[p_hist_name] = state_dict[p_cap_name].clone()
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
logger.info(msg)
logger.info(f"Loaded checkpoint from {ckpt_path}")
else:
logger.warning("No pretrained checkpoint provided, training from scratch")
if config['training']:
optimizer = create_optimizer(config, model_without_ddp)
scheduler = create_scheduler(config, optimizer)
scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)
elif config['generating']:
model.llm.set_input_embeddings(model.text_embedding)
if config['distributed']:
static_graph=config.stage!='stage_1'
if len(config.media_train) > 0:
static_graph = False
model = torch.nn.parallel.DistributedDataParallel(
model_without_ddp,
device_ids=[config['gpu']],
find_unused_parameters=find_unused_parameters, # `False` for image-only task
static_graph=static_graph
)
if config['stage'] == 'stage_3':
return (
model,
model_without_ddp,
optimizer,
scheduler,
scaler,
start_epoch,
global_step,
visdial_step,
avsd_step,
nextqa_step,
config
)
return (
model,
model_without_ddp,
optimizer,
scheduler,
scaler,
start_epoch,
global_step,
config
)
def setup_data(config):
logger.info("[INFO] Creating datasets")
# define the processors
vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
if config['stage'] == 'stage_1':
text_processor = BlipCaptionProcessor(max_words=config.max_cap_len)
if config['debugging']:
train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')
else:
train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'train')
val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')
# cc3m_dataset = ConcatDataset([train_datasets['cc3m'], val_datasets['cc3m']])
# webvid_dataset = ConcatDataset([train_datasets['webvid'], val_datasets['webvid']])
# train_datasets = [cc3m_dataset, webvid_dataset]
train_datasets = list(train_datasets.values())
val_datasets = list(val_datasets.values())
batch_sizes = [config['batch_size_cc3m'], config['batch_size_webvid']]
num_samples = [len(d) for d in train_datasets]
config = set_training_steps(config, num_samples, batch_sizes)
train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)
# val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'test')
# val_dataloader = load_dataloaders(config, val_datasets, 'test', output_dict=True)
if config['stage'] == 'stage_2':
text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50
train_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'train')]
val_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'val')]
batch_sizes = [config['batch_size_champagne']]
num_samples = [len(d) for d in train_datasets]
config = set_training_steps(config, num_samples, batch_sizes)
train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)
if config['stage'] == 'stage_3':
text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50
train_datasets = []
val_datasets = []
for medium in config['media_train']:
if medium == 'visdial':
load_dataset_fn = load_visdial_dataset
elif medium == 'avsd':
load_dataset_fn = load_avsd_dataset
elif medium == 'nextqa':
load_dataset_fn = load_nextqa_dataset
# elif medium == 'champagne':
# load_dataset_fn = load_champagne_dataset
train_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'train'))
for medium in config['media_val']:
if medium == 'visdial':
load_dataset_fn = load_visdial_dataset
elif medium == 'avsd':
load_dataset_fn = load_avsd_dataset
elif medium == 'nextqa':
load_dataset_fn = load_nextqa_dataset
# elif medium == 'champagne':
# load_dataset_fn = load_champagne_dataset
val_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'val'))
batch_sizes = [d.batch_size for d in train_datasets]
num_samples = [len(d) for d in train_datasets]
config = set_training_steps(config, num_samples, batch_sizes)
train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)
return train_dataloaders, val_dataloaders
def setup_data_test(config):
vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
text_processor = BlipDialogProcessor(max_words=config.max_text_len) # max_words = 50
if config.media_test == 'visdial':
load_dataset_fn = load_visdial_dataset
elif config.media_test == 'avsd':
load_dataset_fn = load_avsd_dataset
elif config.media_test == 'nextqa':
load_dataset_fn = load_nextqa_dataset
test_dataset = load_dataset_fn(config, vis_processor, text_processor, 'test')
test_dataloader = DataLoader(
test_dataset, shuffle=False, batch_size=test_dataset.batch_size)
return test_dataloader
# def setup_data_test(config, args):
# tokenizer_experts = BertTokenizer.from_pretrained('bert-base-uncased')
# tokenizer_enc_dec = None
# if config.enc_dec_family == 'flan_t5':
# tokenizer_enc_dec = T5Tokenizer.from_pretrained(config.enc_dec_name)
# elif config.enc_dec_family == 'bart':
# tokenizer_enc_dec = BartTokenizer.from_pretrained(config.enc_dec_name)
# if config['tie_embeddings']:
# tokenizer_experts = tokenizer_enc_dec
# if config['medium'] == 'avsd':
# test_dataset = AVSDDataSet(config, 'avsd', tokenizer_experts, tokenizer_enc_dec, 'test')
# test_dataloader = DataLoader(
# test_dataset, shuffle=False, batch_size=test_dataset.batch_size, collate_fn=test_dataset.collate_fn)
# return test_dataloader

266
models/utils.py Normal file
View file

@ -0,0 +1,266 @@
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import interpolate
from typing import List
logger = logging.getLogger(__name__)
class MLM:
def __init__(
self,
mask_token: int,
padding_token: int,
no_mask_tokens: List[int],
n_tokens: int,
masking_prob: float = 0.15,
randomize_prob: float = 0.1,
no_change_prob: float = 0.1
):
self.mask_token = mask_token
self.padding_token = padding_token
self.no_mask_tokens = list(set(no_mask_tokens + [padding_token, mask_token]))
self.n_tokens = n_tokens
self.masking_prob = masking_prob
self.randomize_prob = randomize_prob
self.no_change_prob = no_change_prob
def __call__(self, x: torch.Tensor):
full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
for tok in self.no_mask_tokens:
full_mask &= x != tok # unmask unwanted tokens --> 0
unchanged_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
mask = full_mask & ~random_token_mask & ~unchanged_mask
y = x.clone().detach()
x.masked_fill_(mask, self.mask_token)
x[random_token_idx] = random_tokens
y.masked_fill_(~full_mask, self.padding_token)
return x, y
def _init_transformer_weights(module, initializer_range=0.02):
"""Initialize the weights. Copied from transformers ViT/Bert model init"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
"""
Add/Remove extra temporal_embeddings as needed.
https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
temp_embed_old: (1, num_frames_old, 1, d)
temp_embed_new: (1, num_frames_new, 1, d)
add_zero: bool, if True, add zero, else, interpolate trained embeddings.
"""
# TODO zero pad
num_frms_new = temp_embed_new.shape[1]
num_frms_old = temp_embed_old.shape[1]
logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
if num_frms_new > num_frms_old:
if add_zero:
temp_embed_new[
:, :num_frms_old
] = temp_embed_old # untrained embeddings are zeros.
else:
temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
elif num_frms_new < num_frms_old:
temp_embed_new = temp_embed_old[:, :num_frms_new]
else: # =
temp_embed_new = temp_embed_old
return temp_embed_new
def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
"""
temp_embed_old: (1, num_frames_old, 1, d)
Returns:
temp_embed_new: (1, num_frames_new, 1, d)
"""
temp_embed_old = temp_embed_old.squeeze(2).permute(
0, 2, 1
) # (1, d, num_frames_old)
temp_embed_new = F.interpolate(
temp_embed_old, num_frames_new, mode="linear"
) # (1, d, num_frames_new)
temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(
2
) # (1, num_frames_new, 1, d)
return temp_embed_new
def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
"""
Args:
pos_embed_old: (1, L_old, d), pre-trained
pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
num_patches_new:
"""
# interpolate position embedding
embedding_size = pos_embed_old.shape[-1]
num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches_new ** 0.5)
if orig_size != new_size:
# class_token and dist_token are kept unchanged
# the extra tokens seems always at the beginning of the position embedding
extra_tokens = pos_embed_old[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_old[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
return interpolated_pos_embed
else:
return pos_embed_old
def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
"""
Args:
state_dict_old: loaded state dict
state_dict_new: state dict for model with new image size
patch_shape_new: new model patch_shape
ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
"""
all_keys = list(state_dict_old.keys())
for key in all_keys:
if "relative_position_index" in key:
state_dict_old.pop(key)
if "relative_position_bias_table" in key:
rel_pos_bias = state_dict_old[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = state_dict_new[key].size()
dst_patch_shape = patch_shape_new
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
dst_patch_shape[1] * 2 - 1
)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
if src_size != dst_size:
# logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
# key, src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
# logger.info("Original positions = %s" % str(x))
# logger.info("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind="cubic")
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy))
.contiguous()
.view(-1, 1)
.to(rel_pos_bias.device)
)
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict_old[key] = new_rel_pos_bias
return state_dict_old
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*repeat_idx)
order_index = torch.LongTensor(
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
)
return torch.index_select(x, dim, order_index.to(x.device))
def mask_logits(target, mask):
return target * mask + (1 - mask) * (-1e10)
class AllGather(torch.autograd.Function):
"""An autograd function that performs allgather on a tensor."""
@staticmethod
def forward(ctx, tensor, args):
output = [torch.empty_like(tensor) for _ in range(args.world_size)]
torch.distributed.all_gather(output, tensor)
ctx.rank = args.rank
ctx.batch_size = tensor.shape[0]
return torch.cat(output, dim=0)
@staticmethod
def backward(ctx, grad_output):
return (
grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
None,
)
allgather_wgrad = AllGather.apply

2213
models/v2dial.py Normal file

File diff suppressed because it is too large Load diff