initial commit

This commit is contained in:
Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

1216
models/backbones/Qformer.py Executable file

File diff suppressed because it is too large Load diff

View file

247
models/backbones/base_model.py Executable file
View file

@ -0,0 +1,247 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import numpy as np
import torch
import torch.nn as nn
from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from models.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf
class BaseModel(nn.Module):
    """Base class for models.

    Adds checkpoint loading, config-driven construction and a parameter
    counter on top of ``nn.Module``. Subclasses are expected to provide
    ``PRETRAINED_MODEL_CONFIG_DICT``, ``from_config`` and, when pretrained
    weights are used, ``load_from_pretrained``; none of these are defined here.
    """

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        """Device of the first parameter returned by ``self.parameters()``."""
        return list(self.parameters())[0].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch in the model keys and the checkpoint keys.

        Args:
            - url_or_filename (str): URL (downloaded through the cache helper)
              or local file path of the checkpoint.

        Returns:
            - The result of ``load_state_dict(..., strict=False)``.

        Raises:
            - RuntimeError: if the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoint sources (the download is not hash-checked).
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # Checkpoints may wrap the weights under a "model" entry.
        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        """Return the absolute path of the default config registered for ``model_type``."""
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define their
        own load_from_pretrained() method.

        Raises:
            - AssertionError: if the path required by the selected mode is missing.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            # load pre-trained weights
            pretrain_path = cfg.get("pretrained", None)
            # The previous form asserted on the message string alone, which is
            # always truthy, so a missing path slipped through as None.
            assert (
                pretrain_path is not None
            ), "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_evaluation(self, **kwargs):
        """Hook invoked before evaluation; does nothing in the base class."""
        pass

    def show_n_params(self, return_str=True):
        """Count the elements of all parameters.

        Args:
            - return_str (bool): when True, format the count as ``"<x>M"``
              (>= 1e6) or ``"<x>K"`` with one decimal; otherwise return the int.
        """
        tot = sum(p.numel() for p in self.parameters())
        if not return_str:
            return tot
        if tot >= 1e6:
            return "{:.1f}M".format(tot / 1e6)
        return "{:.1f}K".format(tot / 1e3)
class BaseEncoder(nn.Module):
    """
    Common parent for primitive encoders (e.g. ViT, TimeSformer).

    Subclasses must override ``forward_features``.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        """Compute features for ``samples``; not implemented in the base class."""
        raise NotImplementedError

    @property
    def device(self):
        """Device of the first parameter returned by ``self.parameters()``."""
        all_params = list(self.parameters())
        first_param = all_params[0]
        return first_param.device
class SharedQueueMixin:
    """Mixin that overwrites a window of the feature queues with new features.

    The host object is expected to provide ``image_queue``, ``text_queue``,
    ``queue_ptr``, ``queue_size`` and, when ``idxs`` is used, ``idx_queue``.
    """

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        """Write the gathered features into the queues at the current pointer.

        Features are gathered through ``concat_all_gather`` first, then their
        transposes replace the queue columns ``[ptr, ptr + batch)`` and the
        pointer advances modulo ``queue_size``.
        """
        # Collect the keys before touching the queue.
        gathered_image = concat_all_gather(image_feat)
        gathered_text = concat_all_gather(text_feat)

        n_new = gathered_image.shape[0]
        start = int(self.queue_ptr)
        assert self.queue_size % n_new == 0  # for simplicity
        stop = start + n_new

        # Overwrite the window (dequeue and enqueue in one step).
        self.image_queue[:, start:stop] = gathered_image.T
        self.text_queue[:, start:stop] = gathered_text.T

        if idxs is not None:
            gathered_idxs = concat_all_gather(idxs)
            self.idx_queue[:, start:stop] = gathered_idxs.T

        # Advance the pointer, wrapping at the end of the queue.
        self.queue_ptr[0] = stop % self.queue_size
class MomentumDistilationMixin:
    """Mixin that keeps momentum copies of models in sync with their sources.

    The host object provides ``model_pairs`` (iterable of ``(model, momentum_model)``
    pairs) and, for ``_momentum_update``, a scalar ``momentum``.
    """

    @torch.no_grad()
    def copy_params(self):
        """Copy each source parameter into its momentum twin and freeze the twin."""
        for pair in self.model_pairs:
            online, target = pair[0], pair[1]
            for src, dst in zip(online.parameters(), target.parameters()):
                dst.data.copy_(src.data)  # initialize
                dst.requires_grad = False  # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        """Blend each momentum parameter towards its source: ``m * old + (1 - m) * new``."""
        for pair in self.model_pairs:
            online, target = pair[0], pair[1]
            for src, dst in zip(online.parameters(), target.parameters()):
                blended = dst.data * self.momentum + src.data * (
                    1.0 - self.momentum
                )
                dst.data = blended
class GatherLayer(torch.autograd.Function):
    """
    All-gather with a backward pass.

    Unlike a plain ``torch.distributed.all_gather``, gradients flow back
    through this function to the local input.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every worker; returns one tensor per worker as a tuple."""
        n_workers = torch.distributed.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(n_workers)]
        torch.distributed.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """All-reduce the stacked incoming gradients and return this rank's slice."""
        stacked = torch.stack(grads)
        torch.distributed.all_reduce(stacked)
        rank = torch.distributed.get_rank()
        return stacked[rank]
def all_gather_with_grad(tensors):
    """
    All-gather ``tensors`` across workers while keeping the autograd graph connected.

    With a single worker the input is returned untouched; otherwise the
    per-worker results of ``GatherLayer`` are concatenated along dim 0.
    """
    # Nothing to gather or reduce when only one process participates.
    if torch.distributed.get_world_size() == 1:
        return tensors

    gathered = GatherLayer.apply(tensors)
    return torch.cat(gathered, dim=0)
@torch.no_grad()
def concat_all_gather(tensor):
    """
    All-gather ``tensor`` across workers and concatenate the results along dim 0.

    Returns the input unchanged when distributed mode is not available/initialized.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if is_dist_avail_and_initialized():
        n_workers = torch.distributed.get_world_size()
        buffers = [torch.ones_like(tensor) for _ in range(n_workers)]
        torch.distributed.all_gather(buffers, tensor, async_op=False)
        return torch.cat(buffers, dim=0)

    # Not running distributed: nothing to gather.
    return tensor
def tile(x, dim, n_tile):
    """Repeat every slice of ``x`` along ``dim`` ``n_tile`` times, keeping copies adjacent.

    ``x`` is first repeated block-wise along ``dim`` and then re-indexed so
    that the copies of slice ``i`` sit next to each other.
    """
    original_len = x.size(dim)

    repeats = [1 for _ in range(x.dim())]
    repeats[dim] = n_tile
    repeated = x.repeat(*repeats)

    # For slice i, pick its copies at i, i + original_len, i + 2 * original_len, ...
    per_slice = [
        original_len * np.arange(n_tile) + i for i in range(original_len)
    ]
    gather_order = torch.LongTensor(np.concatenate(per_slice))
    return torch.index_select(repeated, dim, gather_order.to(repeated.device))

View file

View file

@ -0,0 +1,107 @@
import logging
import torch
from models.utils import (interpolate_pos_relative_bias_beit,
load_temp_embed_with_mismatch)
logger = logging.getLogger(__name__)
def interpolate_pos_embed_beit(state_dict, new_model):
    """Adapt the positional embeddings in ``state_dict`` to ``new_model``.

    The spatial pe is relative and temporal pe is absolute.
    additional temporal pe is padded with 0.

    Args:
        state_dict (dict): The state_dict.
        new_model (nn.Module): The created model.

    Returns: dict. The state_dict with updated positional embeddings.

    """
    # Relative spatial position bias.
    updated = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=new_model.state_dict(),
        patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape,
    )

    # Absolute temporal position embeddings, only when the checkpoint has them.
    temporal_pe_key = "beit.embeddings.temporal_position_embeddings"
    if temporal_pe_key not in updated:
        return updated

    logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
    updated[temporal_pe_key] = load_temp_embed_with_mismatch(
        temp_embed_old=updated[temporal_pe_key],
        temp_embed_new=new_model.state_dict()[temporal_pe_key],
    )
    return updated
def extract_beit_from_vindlu(vindlu_state_dict):
    """Pull the vision-encoder weights out of a VindLU state dict.

    Keeps every entry whose key starts with ``'vision_encoder.'`` and does not
    contain ``'temp_model'``, and strips that leading prefix from the key.

    Args:
        vindlu_state_dict (dict): Mapping from parameter name to value.

    Returns: dict. The selected entries under their prefix-free names.

    """
    prefix = 'vision_encoder.'
    # Slice the prefix off; str.replace would also delete any later
    # occurrence of the same text inside the key.
    return {
        name[len(prefix):]: value
        for name, value in vindlu_state_dict.items()
        if name.startswith(prefix) and 'temp_model' not in name
    }
def build_beit(model_config, image_res, checkpoint=False):
    """build beit with configuration.

    The weights are read from the VindLU checkpoint at
    ``model_config['vindlu_path']``; the architecture config comes from
    ``model_config['beit_config_json']['pretrained']``.

    Args:
        model_config (dict): Must provide 'vindlu_path' and 'beit_config_json'
            (the latter with a 'pretrained' entry).
        image_res (int): The image resolution.
        checkpoint (bool): Whether to enable gradient checkpointing.

    Returns: nn.Module

    """
    from .st_beit import BeitConfig as config_cls
    from .st_beit import BeitModel as model_cls

    # Load onto CPU so a checkpoint saved from a GPU process can be read on
    # any machine; load_state_dict below copies into the model's own tensors.
    # NOTE(review): torch.load unpickles the file; use trusted checkpoints only.
    vindlu_state_dict = torch.load(
        model_config['vindlu_path'], map_location="cpu"
    )['model']
    state_dict = extract_beit_from_vindlu(vindlu_state_dict)
    model_config = model_config['beit_config_json']
    # NOTE(review): the message mentions huggingface weights, but the weights
    # used below come from the VindLU checkpoint; only the config is fetched.
    logger.info(
        f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}."
    )
    # BEiT uses average pooled tokens instead of [CLS] used by other models
    aux_kwargs = {"add_pooling_layer": True}

    logger.info(f"Init new model with new image size {image_res}, and load weights.")

    other_cfg = {}
    vit_config = config_cls.from_pretrained(
        model_config['pretrained'], image_size=image_res, **other_cfg
    )

    model = model_cls(config=vit_config, **aux_kwargs)

    if checkpoint:
        model.gradient_checkpointing_enable()

    # interpolate relative pos bias
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=model.state_dict(),
        patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
    )

    # Drop prompt bias tables before loading.
    for k in list(state_dict.keys()):
        if "prompt_bias_table" in k:
            del state_dict[k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    return model

File diff suppressed because it is too large Load diff

View file

View file

@ -0,0 +1,71 @@
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
    """build text encoder.

    Args:
        model_config (dict): model config; read for 'expert_size',
            'bert_config_<size>' and 'num_layers_<expert_type>_expert'.
        pretrain (bool): Whether to do pretrain or finetuning.
        checkpoint (bool): whether to do gradient_checkpointing.
        expert_type (str): selects the layer-count key; 'modality' also makes
            the cross-attention frequency depend on ``modality_type``.
        modality_type (str): with expert_type 'modality', 'vis' gives
            cross_attention_freq 2 and anything else gives -1.

    Returns: the model built by ``from_pretrained`` (BertForMaskedLM when
        ``pretrain`` is true, BertModel otherwise).

    """
    bert_size = model_config['expert_size']
    bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])

    bert_config.gradient_checkpointing = checkpoint
    bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]

    # Non-modality experts cross-attend in every layer.
    if expert_type != 'modality':
        cross_freq = 1
    elif modality_type == 'vis':
        cross_freq = 2
    else:
        cross_freq = -1
    bert_config.cross_attention_freq = cross_freq

    if not pretrain:
        text_encoder, loading_info = BertModel.from_pretrained(
            f'bert-{bert_size}-uncased',
            config=bert_config,
            add_pooling_layer=True,
            output_loading_info=True,
        )
        return text_encoder

    text_encoder, loading_info = BertForMaskedLM.from_pretrained(
        f'bert-{bert_size}-uncased',
        config=bert_config,
        output_loading_info=True,
    )
    return text_encoder
def build_bert_decoder(model_config, checkpoint):
    """build text decoder the same as the multimodal encoder.

    Args:
        model_config: model config exposing ``text_encoder`` (config,
            fusion_layer, pretrained) and ``vision_encoder.d_model`` attributes.
        checkpoint (bool): whether to do gradient_checkpointing.

    Returns: the BertLMHeadModel built by ``from_pretrained``.

    """
    text_cfg = model_config.text_encoder
    bert_config = BertConfig.from_json_file(text_cfg.config)
    bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint
    bert_config.fusion_layer = 0

    # The decoder keeps only the layers above the encoder's fusion layer.
    remaining_layers = bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
    bert_config.num_hidden_layers = remaining_layers

    text_decoder, loading_info = BertLMHeadModel.from_pretrained(
        model_config.text_encoder.pretrained,
        config=bert_config,
        output_loading_info=True,
    )
    return text_decoder

View file

@ -0,0 +1,546 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""
import collections
import os
import unicodedata
from typing import List, Optional, Tuple
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-multilingual-uncased": 512,
"bert-base-multilingual-cased": 512,
"bert-base-chinese": 512,
"bert-base-german-cased": 512,
"bert-large-uncased-whole-word-masking": 512,
"bert-large-cased-whole-word-masking": 512,
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"TurkuNLP/bert-base-finnish-cased-v1": 512,
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
"wietsedv/bert-base-dutch-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-multilingual-uncased": {"do_lower_case": True},
"bert-base-multilingual-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
"bert-base-german-cased": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered token -> index map.

    The index of a token is its zero-based line number; the trailing newline
    is removed from each line.
    """
    with open(vocab_file, "r", encoding="utf-8") as reader:
        return collections.OrderedDict(
            (line.rstrip("\n"), line_no) for line_no, line in enumerate(reader)
        )
def whitespace_tokenize(text):
    """Strip ``text`` and split it on runs of whitespace; empty input gives ``[]``."""
    stripped = text.strip()
    return stripped.split() if stripped else []
class BertTokenizer(PreTrainedTokenizer):
    r"""
    WordPiece-based BERT tokenizer.

    Inherits from :class:`~transformers.PreTrainedTokenizer`, which provides most of the main methods; refer to that
    superclass for details on them.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token, used for tokens that are not in the vocabulary.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, used between and after the sequences of a pair.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token; it is the first token of a sequence built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values in masked language modeling.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file)
            )

        # Forward (token -> id) and reverse (id -> token) lookup tables.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (token_id, token) for token, token_id in self.vocab.items()
        )

        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        """Lower-casing flag, read from the basic tokenizer."""
        # NOTE(review): basic_tokenizer is only created when
        # do_basic_tokenize is true - confirm this is never read otherwise.
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` into WordPiece tokens, optionally after basic tokenization."""
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)

        pieces = []
        for word in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
            # Tokens in the never_split set are kept whole.
            if word in self.basic_tokenizer.never_split:
                pieces.append(word)
            else:
                pieces += self.wordpiece_tokenizer.tokenize(word)
        return pieces

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        fallback_id = self.vocab.get(self.unk_token)
        return self.vocab.get(token, fallback_id)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        joined = " ".join(tokens)
        return joined.replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding special tokens:

        - single sequence: ``[CLS] X `` (no trailing ``[SEP]``)
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is not None:
            cls = [self.cls_token_id]
            sep = [self.sep_token_id]
            return cls + token_ids_0 + sep + token_ids_1 + sep
        return [self.cls_token_id] + token_ids_0

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Build the special-tokens mask for a token list. Per the upstream documentation this is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [
                1 if token_id in [self.sep_token_id, self.cls_token_id] else 0
                for token_id in token_ids_0
            ]

        first_part = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is None:
            # NOTE(review): this marks a trailing special token although
            # build_inputs_with_special_tokens adds no [SEP] for a single
            # sequence - confirm whether the extra entry is intended.
            return first_part
        return first_part + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs for a sequence or a sequence pair. For a pair the layout is:

        ::

            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        # NOTE(review): the first segment is sized as [CLS] X [SEP] even for a
        # single sequence, which build_inputs_with_special_tokens emits
        # without [SEP] - confirm against callers.
        first_segment = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is None:
            return first_segment
        return first_segment + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary, one token per line in index order, and return the file path as a 1-tuple.

        When ``save_directory`` is an existing directory the file is created inside it; otherwise the argument
        itself (with the optional prefix) is used as the file path.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = prefix + save_directory

        expected_index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            ordered = sorted(self.vocab.items(), key=lambda kv: kv[1])
            for token, token_index in ordered:
                if expected_index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file)
                    )
                    # Resynchronise with the stored index after a gap.
                    expected_index = token_index
                writer.write(token + "\n")
                expected_index += 1
        return (vocab_file,)
class BasicTokenizer(object):
    """
    Basic tokenizer: text cleanup, optional lower casing / accent stripping, and punctuation splitting.

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        self.do_lower_case = do_lower_case
        self.never_split = set([] if never_split is None else never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic tokenization of a piece of text: cleanup, whitespace split, then punctuation split. For sub-word
        tokenization, see WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Extra tokens not to split, merged with the ones given at construction. Kept for backward
                compatibility purposes (see :func:`PreTrainedTokenizer.tokenize`).
        """
        # Merge call-time protected tokens with the instance-level set.
        if never_split:
            protected = self.never_split.union(set(never_split))
        else:
            protected = self.never_split

        cleaned = self._clean_text(text)

        # Added upstream on November 1st, 2018 for the multilingual and Chinese
        # models; it is applied to every model.
        if self.tokenize_chinese_chars:
            cleaned = self._tokenize_chinese_chars(cleaned)

        pieces = []
        for word in whitespace_tokenize(cleaned):
            if word not in protected:
                if self.do_lower_case:
                    word = word.lower()
                    # When lower casing, accents are stripped unless explicitly disabled.
                    if self.strip_accents is not False:
                        word = self._run_strip_accents(word)
                elif self.strip_accents:
                    word = self._run_strip_accents(word)
            pieces.extend(self._run_split_on_punc(word, protected))

        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        decomposed = unicodedata.normalize("NFD", text)
        # Drop characters in Unicode category "Mn".
        return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]

        groups = []
        open_new_group = True
        for ch in text:
            if _is_punctuation(ch):
                # Each punctuation character forms its own group.
                groups.append([ch])
                open_new_group = True
                continue
            if open_new_group:
                groups.append([])
                open_new_group = False
            groups[-1].append(ch)

        return ["".join(group) for group in groups]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        return "".join(
            " " + ch + " " if self._is_chinese_char(ord(ch)) else ch for ch in text
        )

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # A "chinese character" here is anything in the CJK Unicode blocks:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # The CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name: Hangul, Hiragana and Katakana live in other blocks
        # and are handled like all other space-separated scripts.
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        return any(low <= cp <= high for low, high in cjk_ranges)

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        kept = []
        for ch in text:
            code = ord(ch)
            if code == 0 or code == 0xFFFD or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces using a greedy longest-match-first search over the
        given vocabulary.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.

        Words longer than ``max_input_chars_per_word`` and words that cannot be fully covered by vocabulary
        entries are replaced by the unknown token.

        Args:
          text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens.
        """
        result = []
        for word in whitespace_tokenize(text):
            letters = list(word)
            n_letters = len(letters)
            if n_letters > self.max_input_chars_per_word:
                result.append(self.unk_token)
                continue

            pieces = []
            cursor = 0
            fully_matched = True
            while cursor < n_letters:
                match = None
                # Try the longest remaining span first, shrinking from the right.
                for stop in range(n_letters, cursor, -1):
                    candidate = "".join(letters[cursor:stop])
                    if cursor > 0:
                        # Non-initial pieces carry the continuation marker.
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        match = candidate
                        break
                if match is None:
                    fully_matched = False
                    break
                pieces.append(match)
                cursor = stop

            if fully_matched:
                result.extend(pieces)
            else:
                result.append(self.unk_token)
        return result

File diff suppressed because it is too large Load diff

268
models/backbones/blip2.py Executable file
View file

@ -0,0 +1,268 @@
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
# import .backbones.common.dist_utils as dist_utils
# from minigpt4.common.dist_utils import download_cached_file
# from minigpt4.common.utils import is_url
# from minigpt4.common.logger import MetricLogger
from models.backbones.base_model import BaseModel
from models.backbones.Qformer import BertConfig, BertLMHeadModel
from models.backbones.eva_vit import create_eva_vit_g
from transformers import BertTokenizer
class Blip2Base(BaseModel):
    """Shared construction, checkpoint-loading and optimizer helpers for BLIP-2 style models."""

    @classmethod
    def init_tokenizer(cls):
        """Return the ``bert-base-uncased`` tokenizer with ``[DEC]`` registered as its BOS token."""
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context unless the model lives on the CPU.

        Args:
            dtype: autocast dtype used when the model is not on the CPU.

        Returns:
            ``torch.cuda.amp.autocast(dtype=dtype)`` off-CPU, otherwise a no-op context.
        """
        if self.device != torch.device("cpu"):
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        """Build the Q-Former and its learnable query tokens.

        Args:
            num_query_token: number of query tokens (also stored as ``query_length``).
            vision_width: feature width of the vision encoder, stored as ``encoder_width``.
            cross_attention_freq: stored on the config as ``cross_attention_freq``.

        Returns:
            Tuple ``(Qformer, query_tokens)`` where ``query_tokens`` is an ``nn.Parameter``
            of shape ``(1, num_query_token, hidden_size)``.
        """
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # Enable cross-attention; how often it is inserted is governed by cross_attention_freq.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        """Create the EVA ViT-g vision encoder and the LayerNorm applied to its features.

        Only ``model_name == "eva_clip_g"`` is supported.

        Returns:
            Tuple ``(visual_encoder, ln_vision)``.
        """
        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )
        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision

    def load_from_pretrained(self, url_or_filename):
        """Load weights from the ``"model"`` entry of a checkpoint, non-strictly.

        Args:
            url_or_filename: URL (downloaded through the cache helper) or local file path.

        Returns:
            The result of ``load_state_dict(..., strict=False)``.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        # The module-level imports of these helpers are commented out in this file,
        # so they are imported here to avoid a NameError at call time.
        from models.common.dist_utils import download_cached_file
        from models.common.utils import is_url

        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("load checkpoint from %s" % url_or_filename)
        return msg

    def get_optimizer_params(self, weight_decay, lr_scale=1):
        """Group trainable parameters for the optimizer.

        Parameters are split into "decay" and "no_decay" groups (1-D tensors and
        names ending in ``.bias`` get zero weight decay). Parameters whose name
        contains ``visual_encoder`` are further grouped per ViT layer and given a
        layer-wise ``lr_scale``.

        Args:
            weight_decay: weight decay applied to the "decay" groups.
            lr_scale: base of the layer-wise learning-rate decay for the vision encoder.

        Returns:
            List of dicts with keys ``"weight_decay"``, ``"params"`` and ``"lr_scale"``.
        """
        vit_num_layers = self.visual_encoder.get_num_layer()
        lr_scales = [lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2)]

        param_groups = {}
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias"):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay

            if 'visual_encoder' in name:
                layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
            else:
                layer_id = None

            if group_name not in param_groups:
                scale = lr_scales[layer_id] if layer_id is not None else 1
                param_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale,
                }
            param_groups[group_name]["params"].append(param)

        return list(param_groups.values())
def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``.

    Assign this over a module's ``train`` method so that later train/eval
    switches leave its mode untouched. ``mode`` is accepted and ignored.
    """
    return self
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 and returns the input's dtype (fp16-friendly)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Run the normalisation in full precision, then cast back.
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
def compute_sim_matrix(model, data_loader, **kwargs):
    """Compute image-to-text and text-to-image retrieval score matrices.

    Text and image embeddings are computed first, a coarse similarity matrix
    selects the top ``k_test`` candidates per query, and ``model.compute_itm``
    re-scores those candidates. Entries never re-scored stay at ``-100.0``.
    Work is sharded across distributed ranks and summed with ``all_reduce``.

    Args:
        model: object providing ``tokenizer``, ``device``, ``forward_text``,
            ``text_proj``, ``forward_image``, ``vision_proj`` and ``compute_itm``.
        data_loader: loader whose ``dataset`` exposes ``text`` and ``image`` and
            whose batches contain an ``"image"`` entry.
        **kwargs: must contain ``k_test``, the number of candidates to re-score.

    Returns:
        Tuple ``(score_matrix_i2t, score_matrix_t2i)`` of NumPy arrays.
    """
    # These helpers are used below but their module-level imports are commented
    # out in this file; import them here so the function does not hit a NameError.
    # NOTE(review): assumes models.common mirrors the original common package
    # (dist_utils.get_world_size/get_rank and logger.MetricLogger) - confirm.
    import models.common.dist_utils as dist_utils
    from models.common.logger import MetricLogger

    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter="  ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    # ---- text features, computed in fixed-size batches ----
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        text_embed = F.normalize(model.text_proj(text_feat))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    # ---- image features; raw ViT features are kept on the CPU ----
    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = model.vision_proj(image_feat)
        image_embed = F.normalize(image_embed, dim=-1)

        vit_feats.append(vit_feat.cpu())
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    # ---- coarse similarity: per image, max over the first dim of (image_embed @ text^T) ----
    sims_matrix = []
    for image_embed in image_embeds:
        sim_q2t = image_embed @ text_embeds.t()
        sim_i2t, _ = sim_q2t.max(0)
        sims_matrix.append(sim_i2t)
    sims_matrix = torch.stack(sims_matrix, dim=0)

    # ---- image -> text re-scoring, sharded by rank ----
    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    # ---- text -> image re-scoring, sharded by rank ----
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k_test, 1),
            text_atts=text_atts[start + i].repeat(k_test, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    # Each rank filled only its own rows; summing merges the shards.
    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        dist.all_reduce(score_matrix_i2t, op=dist.ReduceOp.SUM)
        dist.all_reduce(score_matrix_t2i, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

110
models/backbones/blip2_outputs.py Executable file
View file

@ -0,0 +1,110 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import (
ModelOutput,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
@dataclass
class BlipSimilarity(ModelOutput):
    """Similarity scores between images and texts, in both directions."""

    # image-to-text / text-to-image similarities
    sim_i2t: torch.FloatTensor = None
    sim_t2i: torch.FloatTensor = None

    # "_m" variants - presumably from the momentum encoders; confirm with the producer
    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    # optional similarity targets
    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None
@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Intermediate outputs produced by BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
    """

    # --- uni-modal features ---
    image_embeds: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # --- multimodal encoder: intermediate outputs ---
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # --- multimodal decoder: intermediate outputs ---
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None
@dataclass
class BlipOutput(ModelOutput):
    """Top-level BLIP output: similarities, intermediate outputs and losses."""

    # Optional because some finetuned models (e.g. BlipVQA) do not compute similarity.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: BlipIntermediateOutput = None

    # Total loss followed by its named components (itc / itm / lm).
    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None
@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Features returned by BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_embeds_proj: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_embeds_proj: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
        multimodal_embeds: (torch.FloatTensor), optional

    The first embedding or feature is for the [CLS] token.

    The ``*_proj`` features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None

View file

@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionEncoder(nn.Module):
    """Frozen Hugging Face CLIP vision model used as an image feature extractor.

    Args:
        encoder_name: Hugging Face model id or path of the CLIP vision model.
        delay_load: when True only the config is fetched; call ``load_model()``
            before running ``forward``.
    """

    def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = encoder_name
        # Hidden state to read (-1 = last) and which tokens to return.
        self.select_layer = -1
        self.select_feature = "patch"
        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self):
        """Load the image processor and the vision model, and freeze the model's weights."""
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
        self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
        self.vision_encoder.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return the hidden state at ``select_layer`` from a forward output.

        Raises:
            ValueError: if ``select_feature`` is neither ``'patch'`` nor ``'cls_patch'``.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature not in ('patch', 'cls_patch'):
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        # NOTE(review): both modes return every token unchanged (the 'patch'
        # branch used the no-op slice [:, :]). If 'patch' is meant to drop the
        # leading CLS token this should be [:, 1:] - confirm with downstream
        # consumers before changing, since it alters the sequence length.
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode a batch tensor, or a list/tuple of single-image tensors.

        Args:
            images: a tensor passed to the encoder as is, or a list/tuple of
                tensors that are each unsqueezed to a batch of one.

        Returns:
            A feature tensor for tensor input, or a list of feature tensors
            (one per image) for list/tuple input, cast back to the input dtype.

        Raises:
            RuntimeError: if the encoder has not been loaded yet.
        """
        if not self.is_loaded:
            raise RuntimeError(
                "CLIPVisionEncoder was created with delay_load=True; "
                "call load_model() before forward()."
            )

        if isinstance(images, (list, tuple)):
            image_features = []
            for image in images:
                image_forward_out = self.vision_encoder(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_encoder(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """All-zero feature of shape ``(1, hidden_size)`` on the encoder's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the loaded vision encoder."""
        return self.vision_encoder.dtype

    @property
    def device(self):
        """Device of the loaded vision encoder."""
        return self.vision_encoder.device

    @property
    def config(self):
        """Config of the loaded encoder, or the config fetched at construction when loading was delayed."""
        if self.is_loaded:
            return self.vision_encoder.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size taken from the config."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """``(image_size // patch_size) ** 2`` taken from the config."""
        return (self.config.image_size // self.config.patch_size) ** 2

View file

@ -0,0 +1,141 @@
import glog as logger
import re
import json
from peft import LoraConfig, get_peft_model
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart import BartConfig, BartForConditionalGeneration, BartEncoder, BartForCausalLM
def _wrap_with_lora(model, lora_config_path):
    """Wrap ``model`` with LoRA adapters targeting every ``Linear`` sub-module name.

    Args:
        model: the module to wrap.
        lora_config_path (str): path to a JSON file with ``LoraConfig`` keyword
            arguments; its ``target_modules`` entry is overwritten.

    Returns:
        The PEFT-wrapped model.
    """
    with open(lora_config_path, 'r', encoding='utf-8') as f:
        lora_kwargs = json.load(f)

    # The module repr lists children as "(name): Linear(...)"; collect the
    # distinct names. Sorting keeps the target list deterministic across runs.
    linear_layer_names = re.findall(r'\((\w+)\): Linear', repr(model))
    lora_kwargs['target_modules'] = sorted(set(linear_layer_names))

    return get_peft_model(model, LoraConfig(**lora_kwargs))


def build_encoder_decoder(model_config):
    """Build the (encoder-) decoder model used for answer generation.

    Side effect: stores the model width in ``model_config['enc_dec_dim']``.

    Args:
        model_config (dict): model config. Reads ``enc_dec_name``,
            ``enc_dec_family`` (``'flan_t5'`` or ``'bart'``), ``use_lora_enc_dec``,
            ``use_decoder_only`` (bart only) and ``lora_config`` (LoRA only).

    Returns:
        The pretrained model, wrapped with LoRA when ``use_lora_enc_dec`` is set.

    Raises:
        ValueError: if ``enc_dec_family`` is not supported.
    """
    logger.info('[INFO] Loading Encoder Decoder [Type = {}]'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        if model_config['use_decoder_only']:
            model_cls = BartForCausalLM
        else:
            model_cls = BartForConditionalGeneration
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    enc_dec_config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = enc_dec_config.d_model

    enc_dec = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=enc_dec_config
    )

    if model_config['use_lora_enc_dec']:
        enc_dec = _wrap_with_lora(enc_dec, model_config['lora_config'])

    return enc_dec


def build_encoder(model_config, expert_type, modality=None):
    """Build an expert encoder.

    Side effect: stores the model width in ``model_config['enc_dec_dim']``.

    Args:
        model_config (dict): model config. Reads ``enc_dec_name``,
            ``enc_dec_family`` (``'flan_t5'`` or ``'bart'``),
            ``num_layers_modality_expert_<family>``,
            ``num_layers_grounding_expert_<family>``, ``use_lora_expert`` and
            ``lora_config`` (LoRA only).
        expert_type: forwarded to the model's ``from_pretrained``.
        modality: optional modality, forwarded to the model's ``from_pretrained``.

    Returns:
        The pretrained expert, wrapped with LoRA when ``use_lora_expert`` is set.

    Raises:
        ValueError: if ``enc_dec_family`` is not supported.
    """
    log_txt = '[INFO] Loading {} Expert'.format(expert_type)
    if modality is not None:
        log_txt += ' [Modality = {}]'.format(modality)
    log_txt += ' [Type = {}]'.format(model_config['enc_dec_name'])

    logger.info(log_txt)

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartEncoder
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    # Per-family layer counts for the two expert kinds.
    config.modality_expert_layers = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    config.grounding_expert_layers = model_config['num_layers_grounding_expert_{}'.format(model_config['enc_dec_family'])]
    model_config['enc_dec_dim'] = config.d_model

    expert = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config,
        expert_type=expert_type,
        modality=modality
    )

    if model_config['use_lora_expert']:
        expert = _wrap_with_lora(expert, model_config['lora_config'])

    return expert

View file

@ -0,0 +1,65 @@
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart_original import BartConfig, BartForConditionalGeneration, BartEncoder
import glog as logger
def build_encoder_decoder(model_config):
    """Build the (encoder-) decoder model used for answer generation.

    Also records the model width in ``model_config['enc_dec_dim']``.

    Args:
        model_config (dict): model config; reads ``enc_dec_name`` and ``enc_dec_family``.

    Returns:
        The pretrained model for the requested family.
    """
    logger.info('[INFO] Loading Encoder Decoder: {}'.format(model_config['enc_dec_name']))

    supported = {
        'flan_t5': (T5Config, T5ForConditionalGeneration),
        'bart': (BartConfig, BartForConditionalGeneration),
    }
    family = model_config['enc_dec_family']
    if family not in supported:
        raise ValueError('{} is not supported'.format(family))
    config_cls, model_cls = supported[family]

    pretrained_name = model_config['enc_dec_name']
    config = config_cls.from_pretrained(pretrained_name)
    model_config['enc_dec_dim'] = config.d_model
    return model_cls.from_pretrained(pretrained_name, config=config)
def build_encoder(model_config):
    """Build the expert encoder from the configured encoder-decoder checkpoint.

    Also records the model width in ``model_config['enc_dec_dim']`` and sets the
    config's ``encoder_layers`` to ``model_config['num_layers_modality_expert']``.

    Args:
        model_config (dict): model config; reads ``enc_dec_name``,
            ``enc_dec_family`` and ``num_layers_modality_expert``.

    Returns:
        The pretrained expert for the requested family.
    """
    logger.info('[INFO] Loading Expert as Encoder of {}'.format(model_config['enc_dec_name']))

    supported = {
        'flan_t5': (T5Config, T5ForConditionalGeneration),
        'bart': (BartConfig, BartEncoder),
    }
    family = model_config['enc_dec_family']
    if family not in supported:
        raise ValueError('{} is not supported'.format(family))
    config_cls, model_cls = supported[family]

    pretrained_name = model_config['enc_dec_name']
    config = config_cls.from_pretrained(pretrained_name)
    model_config['enc_dec_dim'] = config.d_model
    config.encoder_layers = model_config['num_layers_modality_expert']

    return model_cls.from_pretrained(pretrained_name, config=config)

View file

@ -0,0 +1,19 @@
from typing import Optional, Tuple
import torch
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
    """Sequence-to-sequence output bundle: loss, logits, cache and encoder/decoder states."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None

    # decoder side
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

    # encoder side
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

455
models/backbones/eva_vit.py Executable file
View file

@ -0,0 +1,455 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from models.common.dist_utils import download_cached_file
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
    """Stochastic depth: drops paths per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the functional drop_path with this module's training flag.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given (or are falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # Dropout after the activation is intentionally left out (per the
        # original note: to follow the original BERT implementation);
        # dropout is applied only to the block output.
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional q/v biases and relative position bias.

    When ``window_size`` is given, a learned relative position bias table
    (including three extra entries for the class token) is added to the
    attention logits. An externally supplied ``rel_pos_bias`` can be added too.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One fused projection for q, k and v; biases (q and v only) are separate parameters.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # (2*Wh-1) * (2*Ww-1) in-window offsets, plus 3 entries for
            # cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            self.register_buffer(
                "relative_position_index",
                self._relative_position_index(window_size, self.num_relative_distance),
            )
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _relative_position_index(window_size, num_relative_distance):
        """Build the (Wh*Ww+1, Wh*Ww+1) lookup of bias-table rows for every token pair."""
        win_h, win_w = window_size[0], window_size[1]
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        delta = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # Shift both offsets to start at 0 and combine them into one flat index.
        pair_index = (delta[0] + (win_h - 1)) * (2 * win_w - 1) + (delta[1] + (win_w - 1))

        index = torch.zeros(size=(win_h * win_w + 1,) * 2, dtype=delta.dtype)
        index[1:, 1:] = pair_index  # Wh*Ww, Wh*Ww
        # Row/column 0 belong to the class token; assignment order matters.
        index[0, 0:] = num_relative_distance - 3
        index[0:, 0] = num_relative_distance - 2
        index[0, 0] = num_relative_distance - 1
        return index

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        fused_bias = None
        if self.q_bias is not None:
            # Keys have no bias: a zero block sits between the q and v biases.
            fused_bias = torch.cat(
                (self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        projected = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        projected = projected.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Indexing (rather than tuple-unpacking the tensor) keeps torchscript happy.
        queries = projected[0] * self.scale
        keys = projected[1]
        values = projected[2]

        scores = queries @ keys.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            learned_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(side, side, -1)  # Wh*Ww,Wh*Ww,nH
            learned_bias = learned_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            scores = scores + learned_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            scores = scores + rel_pos_bias

        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ values).transpose(1, 2).reshape(B, N, -1)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection.

    When ``init_values`` is a positive number, each residual branch is scaled
    by a learnable per-channel vector (``gamma_1`` / ``gamma_2``) initialised
    to that value.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        use_branch_scale = init_values is not None and init_values > 0
        if use_branch_scale:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        scaled = self.gamma_1 is not None

        attn_branch = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if scaled:
            attn_branch = self.gamma_1 * attn_branch
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        if scaled:
            mlp_branch = self.gamma_2 * mlp_branch
        x = x + self.drop_path(mlp_branch)
        return x
class PatchEmbed(nn.Module):
    """Image to patch embedding via a strided convolution."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]
        self.patch_shape = (rows, cols)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = cols * rows

        # Kernel and stride both equal the patch size, so patches do not overlap.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, rows, cols) -> (B, rows*cols, embed_dim)
        patches = self.proj(x)
        return patches.flatten(2).transpose(1, 2)
class RelativePositionBias(nn.Module):
    """Learned relative position bias over a 2-D window plus a class token.

    ``forward()`` takes no input and returns the bias as a tensor of shape
    ``(num_heads, Wh*Ww + 1, Wh*Ww + 1)``.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        win_h, win_w = window_size[0], window_size[1]
        # (2*Wh-1) * (2*Ww-1) in-window offsets, plus 3 entries for
        # cls->token, token->cls and cls->cls.
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Pair-wise relative position index for each token inside the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        delta = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # Shift both offsets to start at 0 and combine them into one flat index.
        pair_index = (delta[0] + (win_h - 1)) * (2 * win_w - 1) + (delta[1] + (win_w - 1))

        index = torch.zeros(size=(win_h * win_w + 1,) * 2, dtype=delta.dtype)
        index[1:, 1:] = pair_index  # Wh*Ww, Wh*Ww
        # Row/column 0 belong to the class token; assignment order matters.
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", index)

    def forward(self):
        side = self.window_size[0] * self.window_size[1] + 1
        gathered = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = gathered.view(side, side, -1)  # Wh*Ww,Wh*Ww,nH
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    In this variant the final norm and the classification head are disabled,
    so ``forward`` returns the token sequence produced by the last block with
    the class token at position 0.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
        """Build the patch embedding, positional terms and ``depth`` transformer blocks.

        ``num_classes`` is only recorded on the instance. ``use_mean_pooling``
        and ``init_scale`` are accepted for signature compatibility but are
        not used, because no head or pooling norm is created here.
        """
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # +1 for the class token.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        self.use_checkpoint = use_checkpoint

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Scale each block's output projections down by ``sqrt(2 * layer_index)`` (1-based)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head, or ``None`` if none has been attached.

        ``__init__`` does not create a head, so one only exists after
        ``reset_classifier`` has been called. Returning ``None`` here avoids
        the AttributeError an unconditional ``self.head`` access would raise.
        """
        return getattr(self, "head", None)

    def reset_classifier(self, num_classes, global_pool=''):
        """Attach a fresh linear head (Identity when ``num_classes`` <= 0); ``global_pool`` is ignored."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _prepare_tokens(self, x):
        """Embed patches, prepend the class token, add position embeddings and apply dropout."""
        x = self.patch_embed(x)
        batch_size = x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Run all blocks and return the un-normalised token sequence of the last block."""
        x = self._prepare_tokens(x)
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.use_checkpoint:
                # Activation checkpointing: recompute the block during backward.
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x

    def forward(self, x):
        """Return ``forward_features(x)``; no head is applied."""
        return self.forward_features(x)

    def get_intermediate_layers(self, x):
        """Return the output of every block as a list (checkpointing is not used here)."""
        x = self._prepare_tokens(x)
        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)
        return features

    def get_num_layer(self, var_name=""):
        """Map a parameter name to a layer index (0 for embeddings, ``i + 1`` for ``blocks.i``)."""
        if var_name in ("cls_token", "mask_token", "pos_embed"):
            return 0
        if var_name.startswith("patch_embed"):
            return 0
        if var_name.startswith("rel_pos_bias"):
            return len(self.blocks) - 1
        if var_name.startswith("blocks"):
            layer_id = int(var_name.split('.')[1])
            return layer_id + 1
        return len(self.blocks)
def interpolate_pos_embed(model, checkpoint_model):
    """Resize the checkpoint's absolute position embedding in place to fit ``model``.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: state-dict-like mapping; its ``'pos_embed'`` entry is
            replaced when the patch grid size differs from the model's.

    Nothing happens when the checkpoint has no ``'pos_embed'`` entry, when the
    model has no absolute position embedding (``pos_embed`` is ``None``), or
    when both grids already have the same size. Grids are assumed square.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    model_pos_embed = getattr(model, 'pos_embed', None)
    if model_pos_embed is None:
        # Model built without absolute position embeddings: there is no
        # target shape to interpolate to, so leave the checkpoint untouched.
        return

    pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens that are not patches (e.g. the class token) are kept unchanged.
    num_extra_tokens = model_pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)
def convert_weights_to_fp16(model: nn.Module):
    """Cast weights and biases of every Conv1d / Conv2d / Linear submodule to fp16, in place.

    Other module types (e.g. LayerNorm) keep their original precision.
    """
    half_precision_layers = (nn.Conv1d, nn.Conv2d, nn.Linear)

    def _to_half(module):
        if not isinstance(module, half_precision_layers):
            return
        module.weight.data = module.weight.data.half()
        bias = module.bias
        if bias is not None:
            bias.data = bias.data.half()

    model.apply(_to_half)
def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
    """Build the EVA ViT-g encoder and load its pretrained weights.

    The checkpoint is fetched through ``download_cached_file``, its position
    embedding is resized to ``img_size`` if needed, and it is loaded
    non-strictly. When ``precision == "fp16"`` the Conv/Linear weights are
    cast to half precision; any other value leaves the weights as loaded.
    """
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    model = VisionTransformer(**vit_kwargs)

    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    cached_file = download_cached_file(url, check_hash=False, progress=True)
    # NOTE(review): torch.load unpickles a downloaded file; this trusts the URL above.
    state_dict = torch.load(cached_file, map_location="cpu")
    interpolate_pos_embed(model, state_dict)

    # strict=False: missing or unexpected keys are tolerated and not reported.
    model.load_state_dict(state_dict, strict=False)

    if precision == "fp16":
        convert_weights_to_fp16(model)
    return model

View file

@ -0,0 +1,895 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from transformers import LlamaTokenizer
from transformers import BitsAndBytesConfig
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
import time
import numpy as np
from minigpt4.models import policies
@registry.register_model("mini_gpt4_llama_v2")
class MiniGPT4_llama_v2(Blip2Base):
"""
BLIP2 GPT-LLAMA model.
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/models/minigpt4.yaml",
}
def __init__(
    self,
    vit_model="eva_clip_g",
    img_size=224,
    drop_path_rate=0,
    use_grad_checkpoint=False,
    vit_precision="fp16",
    freeze_vit=True,
    llama_model="",
    prompt_path="",
    prompt_template="",
    max_txt_len=32,
    low_resource=False,  # use 8 bit and put vit in cpu
    end_sym='\n',
    lora_r = 8,
    lora_target_modules = ["q_proj","v_proj"],
    lora_alpha=16,
    lora_dropout= 0.05,
    ckpt_path = "",
    system_prompt= False,
    chat_template=False,
    token_pooling=True,
    use_grad_checkpoint_llm=False,
    max_context_len=3800,
    remove_template = False,
):
    """Build the vision encoder, the LoRA-wrapped language model and the projection layer.

    Args:
        vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision:
            forwarded to ``init_vision_encoder``. ``vit_precision`` is forced
            to ``"fp32"`` when the encoder is not frozen.
        freeze_vit: freeze the vision encoder and its layer norm and pin them
            in eval mode.
        llama_model: name/path given to ``from_pretrained``; a name containing
            ``"Mistral"`` selects the Mistral implementation, otherwise Llama.
        prompt_path, prompt_template: optional prompt file; only lines that
            contain ``<ImageHere>`` are kept and formatted with the template.
        max_txt_len, end_sym, system_prompt, chat_template, remove_template,
            max_context_len: stored on the instance for later use.
        low_resource: load the language model in 8 bit on the current CUDA device.
        lora_r, lora_target_modules, lora_alpha, lora_dropout: LoRA settings.
        ckpt_path: not used in this constructor.
        token_pooling: when true the projection expects 4 concatenated ViT
            tokens (1408 * 4 features) instead of one (1408).
        use_grad_checkpoint_llm: enable gradient checkpointing on the language model.
    """
    super().__init__()

    # Select the causal-LM implementation from the checkpoint name.
    if "Mistral" in llama_model:
        from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
        print("Mistral model")
        self.model_type = "Mistral"
    else:
        from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
        print("Llama model")
        self.model_type = "Llama"

    self.tokenizer = self.init_tokenizer()
    self.low_resource = low_resource
    self.token_pooling = token_pooling
    self.remove_template = remove_template
    print("token pooling", self.token_pooling)
    self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
    self.max_context_len = max_context_len
    self.chat_template = chat_template

    # ---- vision encoder ----
    if freeze_vit:
        print("vit precision", vit_precision)
    else:
        # A trainable encoder is always built in full precision.
        vit_precision = "fp32"
    self.visual_encoder, self.ln_vision = self.init_vision_encoder(
        vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
    )
    if freeze_vit:
        for frozen_name in ("visual_encoder", "ln_vision"):
            module = getattr(self, frozen_name)
            for _, param in module.named_parameters():
                param.requires_grad = False
            module = module.eval()
            # Replace train() so later .train() calls cannot leave eval mode.
            module.train = disabled_train
            setattr(self, frozen_name, module)
        logging.info("freeze vision encoder")
        print("freeze the vision encoder")
    else:
        print("unfreeze the vision encoder")
    print('Loading VIT Done')

    # ---- language model ----
    print('Loading LLAMA')
    self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
    self.llama_tokenizer.pad_token = "$$"
    self.system_prompt = system_prompt

    print("self.low_resource",self.low_resource)
    load_kwargs = {"torch_dtype": torch.float16}
    if self.low_resource:
        load_kwargs["load_in_8bit"] = True
        load_kwargs["device_map"] = {'': torch.cuda.current_device()}
    self.llama_model = llm_model.from_pretrained(llama_model, **load_kwargs)

    self.llama_model = prepare_model_for_int8_training(self.llama_model)
    loraconfig = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    self.llama_model = get_peft_model(self.llama_model, loraconfig)
    self.llama_model.print_trainable_parameters()

    if self.use_grad_checkpoint_llm:
        self.llama_model.gradient_checkpointing_enable()
    print('Loading LLAMA Done')

    # ---- vision -> language projection ----
    vit_width = 1408
    proj_in_features = vit_width * 4 if self.token_pooling else vit_width
    self.llama_proj = nn.Linear(proj_in_features, self.llama_model.config.hidden_size)

    self.max_txt_len = max_txt_len
    self.end_sym = end_sym

    # ---- optional training prompts ----
    if not prompt_path:
        self.prompt_list = []
        return
    with open(prompt_path, 'r') as f:
        raw_prompts = f.read().splitlines()
    filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
    self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
    print('Load {} training prompts'.format(len(self.prompt_list)))
    print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
def encode_img(self, image):
    """Encode images (or video frames) into language-model input embeddings.

    Inputs with more than 4 dimensions are flattened so that all leading
    dimensions become one batch dimension. The class token of the vision
    encoder output is dropped; with ``token_pooling`` every 4 consecutive
    tokens are concatenated along the feature axis before projection.

    Returns:
        ``(inputs_llama, atts_llama)``: projected embeddings and an all-ones
        ``long`` attention mask over their token axis.
    """
    source_device = image.device
    if image.dim() > 4:
        # e.g. (batch, time, C, H, W) -> (batch * time, C, H, W)
        image = image.reshape(-1, *image.shape[-3:])

    with self.maybe_autocast():
        vit_tokens = self.visual_encoder(image)
        image_embeds = self.ln_vision(vit_tokens).to(source_device)
        image_embeds = image_embeds[:, 1:, :]  # drop the first (CLS) token
        bs, pn, hs = image_embeds.shape
        if self.token_pooling:
            image_embeds = image_embeds.view(bs, pn // 4, hs * 4)

        inputs_llama = self.llama_proj(image_embeds)
        mask_shape = inputs_llama.size()[:-1]
        atts_llama = torch.ones(mask_shape, dtype=torch.long).to(image.device)
    return inputs_llama, atts_llama
def get_context_emb(self, prompt, img_list):
    """Interleave text-segment embeddings with image embeddings for one prompt.

    The prompt is split on ``<ImageHere>``; each placeholder is replaced by
    the corresponding entry of ``img_list``. Special tokens are added by the
    tokenizer for the first text segment only.

    Returns:
        A single tensor concatenated along dimension 1.

    Raises:
        AssertionError: if the number of placeholders differs from ``len(img_list)``.
    """
    target_device = img_list[0].device
    segments = prompt.split('<ImageHere>')
    assert len(segments) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    seg_embs = []
    for position, segment in enumerate(segments):
        # only add bos to the first seg
        encoded = self.llama_tokenizer(
            segment, return_tensors="pt", add_special_tokens=position == 0).to(target_device)
        seg_embs.append(self.embed_tokens(encoded.input_ids))

    pieces = []
    for text_emb, img_emb in zip(seg_embs[:-1], img_list):
        pieces.extend((text_emb, img_emb))
    pieces.append(seg_embs[-1])
    return torch.cat(pieces, dim=1)
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
    """Build right-padded input embeddings from prompts and optional image embeddings.

    Args:
        img_embeds: per-sample image embeddings, or ``None`` for text-only prompts.
        atts_img: attention mask returned unchanged when no prompts are given.
        prompts: list of prompt strings; ``<ImageHere>`` marks where images go.
        lengths: optional per-sample number of images actually present; when
            given, each sample's embeddings are flattened and truncated to
            ``lengths[i] * tokens_per_image`` rows.

    Returns:
        ``(embeddings, attention_mask)``. In the multi-modal case both are
        padded (or truncated) to exactly ``self.max_context_len`` positions.

    A prompt without any ``<ImageHere>`` placeholder yields its text embedding
    only; its image embeddings are not inserted.
    """
    if prompts is None or len(prompts) == 0:
        # prompts is not provided, just return the original image embedding
        return img_embeds, atts_img

    if img_embeds is None:
        # prompt is provided but there is no image embedding: right-padded prompt embedding
        self.llama_tokenizer.padding_side = "right"
        prompt_tokens = self.llama_tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
            add_special_tokens=False
        ).to(self.device)
        prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
        return prompt_embeds, prompt_tokens.attention_mask

    # multi-modal embedding, right padded
    device = img_embeds.device
    emb_lists = []
    for sample_idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
        pn = each_img_embed.shape[-2]  # tokens per image
        if lengths is not None:
            each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
            each_img_embed = each_img_embed[:lengths[sample_idx] * pn]

        p_segs = each_prompt.split('<ImageHere>')
        pieces = []
        # seg_idx is deliberately distinct from sample_idx (the original
        # reused one name for both loops).
        for seg_idx, seg in enumerate(p_segs[:-1]):
            p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(device)
            pieces.append(self.embed_tokens(p_tokens.input_ids))
            pieces.append(each_img_embed[None][:, seg_idx * pn:(seg_idx + 1) * pn])
        p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(device)
        pieces.append(self.embed_tokens(p_tokens.input_ids))
        # One flat concatenation also works when the prompt has no placeholder
        # (pieces then holds only the trailing text segment).
        emb_lists.append(torch.cat(pieces, dim=1))

    emb_lens = [emb.shape[1] for emb in emb_lists]
    pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=device))

    max_length = self.max_context_len
    wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
    wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=device)
    for i, emb in enumerate(emb_lists):
        length = min(emb_lens[i], self.max_context_len)
        wrapped_embs[i, :length] = emb[:, :length]
        wrapped_atts[i, :length] = 1
    return wrapped_embs, wrapped_atts
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
    """Splice each output sequence in right after the valid part of its input.

    Both inputs and outputs are expected to be right padded. For every row
    the result is ``[valid input, output, input padding]``, so the input's
    padding ends up after the output.

    Returns:
        ``(cat_embs, cat_atts, input_lens)`` where ``input_lens`` is a list
        with one entry per row: the sum of that row's input attention mask.
    """
    def _splice(padded_input, output, valid_len):
        return torch.cat([padded_input[:valid_len], output, padded_input[valid_len:]])

    input_lens = []
    merged_embs = []
    merged_atts = []
    for row in range(input_embs.size(0)):
        valid_len = input_atts[row].sum()
        input_lens.append(valid_len)
        merged_embs.append(_splice(input_embs[row], output_embs[row], valid_len))
        merged_atts.append(_splice(input_atts[row], output_atts[row], valid_len))

    return torch.stack(merged_embs), torch.stack(merged_atts), input_lens
def get_conv_emb(self, conv_q, conv_a, conv_img):
    """Concatenate conversation turns so that only the answers are regressed.

    For every sample the question and answer embeddings of each turn are
    concatenated in order. Question positions get target ``-100``; answer
    positions get the answer token ids. Samples are then right padded (or
    truncated) to ``min(longest sample, self.max_txt_len)``.

    Returns:
        ``(regress_embeds, regress_attn, targets)``.
    """
    batch_size = len(conv_q)
    regress_embs_list = []
    targets_list = []

    for batch_idx in range(batch_size):
        q_turns, a_turns = conv_q[batch_idx], conv_a[batch_idx]
        turn_imgs = conv_img[batch_idx]

        q_embs = []
        for q, img in zip(q_turns, turn_imgs):
            wrapped, _ = self.prompt_wrap(
                img_embeds=img,
                atts_img=None,
                prompts=[q],
                lengths=[img.shape[1]] if img is not None else None)
            q_embs.append(wrapped)
        a_tokens = [
            self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device)
            for a in a_turns
        ]

        emb_parts = []
        target_parts = []
        for turn, q_emb in enumerate(q_embs):
            answer_ids = a_tokens[turn].input_ids
            emb_parts.append(q_emb)
            # questions are ignored by the loss
            target_parts.append(torch.ones_like(q_emb[..., 0], dtype=torch.int) * -100)
            emb_parts.append(self.embed_tokens(answer_ids))
            target_parts.append(answer_ids)

        regress_embs_list.append(torch.cat(emb_parts, dim=1))
        targets_list.append(torch.cat(target_parts, dim=1))

    longest = max([target.shape[1] for target in targets_list])
    max_len = min(longest, self.max_txt_len)
    hidden_size = regress_embs_list[-1].shape[-1]

    regress_embeds = torch.zeros([batch_size, max_len, hidden_size], device=self.device)
    regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
    targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100

    for batch_idx, (sample_emb, sample_target) in enumerate(zip(regress_embs_list, targets_list)):
        cur_len = sample_emb.shape[1]
        # Slices clamp to max_len, so over-long samples are truncated.
        regress_embeds[batch_idx, :cur_len] = sample_emb[0, :max_len]
        regress_attn[batch_idx, :cur_len] = 1
        targets[batch_idx, :cur_len] = sample_target[0, :max_len]

    return regress_embeds, regress_attn, targets
def preparing_embedding(self, samples):
    """Build the conditioning and regression embeddings for one batch.

    Args:
        samples: dict. Keys read here: ``image`` (optional), and either the
            conversation keys ``conv_q`` / ``conv_a`` / ``connect_sym`` or
            ``instruction_input`` (optional), ``length`` (optional) and
            ``answer``.

    Returns:
        ``(cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets)``.
        In the conversation case the conditioning tensors are empty slices
        because questions and answers are both in the regression tensors.
    """
    def remove_special_tokens(data):
        for tag in (" [caption]", " [vqa]", " [grounding]", " [identify]", " [refer]"):
            data = [instruct.replace(tag, "") for instruct in data]
        return data

    ### prepare input tokens
    if 'image' in samples:
        img_embeds, img_atts = self.encode_img(samples["image"])
    else:
        img_embeds = img_atts = None

    if 'conv_q' in samples:
        # handling conversation datasets
        conv_q, conv_a = samples['conv_q'], samples['conv_a']

        connect_sym = samples['connect_sym'][0]
        conv_q = [q.split(connect_sym) for q in conv_q]
        conv_a = [a.split(connect_sym) for a in conv_a]
        conv_img = assign_imgs(conv_q, img_embeds)

        if self.chat_template:
            conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]

        regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
        cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets

    instruction = samples["instruction_input"] if "instruction_input" in samples else None
    # The template options only apply when an instruction exists; without
    # this guard a missing "instruction_input" raised TypeError on None.
    if instruction is not None:
        if self.remove_template:
            instruction = remove_special_tokens(instruction)
        if self.chat_template:
            instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]

    if 'length' in samples:
        # the input is an image train (like videos): regroup frames per sample
        _, pn, hs = img_embeds.shape
        img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
        cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
    else:
        cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

    ### prepare target tokens
    self.llama_tokenizer.padding_side = "right"
    text = [t + self.end_sym for t in samples["answer"]]

    regress_tokens = self.llama_tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=self.max_txt_len,
        add_special_tokens=False
    ).to(self.device)

    regress_token_ids = regress_tokens.input_ids
    regress_atts = regress_tokens.attention_mask
    # padding positions are ignored by the loss
    part_targets = regress_token_ids.masked_fill(
        regress_token_ids == self.llama_tokenizer.pad_token_id, -100
    )
    regress_embeds = self.embed_tokens(regress_token_ids)

    return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
def forward(self, samples, reduction="mean"):
    """Compute the language-modelling loss for one batch.

    The conditioning embeddings and the embeddings to regress are merged, a
    BOS embedding is prepended, and the labels are ``-100`` everywhere
    except at the regression positions.

    Args:
        samples: batch dict consumed by ``preparing_embedding``.
        reduction: forwarded to the language model call.

    Returns:
        ``{"loss": loss}``.
    """
    # prepare the embedding to condition and the embedding to regress
    cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
        self.preparing_embedding(samples)

    # concat the embedding to condition and the embedding to regress
    inputs_embeds, attention_mask, input_lens = \
        self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
    # Shape diagnostics go to the debug log instead of stdout so that they
    # do not flood the console on every training step.
    logging.debug("inputs_embeds shape %s", inputs_embeds.shape)
    logging.debug("cond_embeds shape %s", cond_embeds.shape)
    logging.debug("regress_embeds shape %s", regress_embeds.shape)

    # get bos token embedding and add it at the beginning
    bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.embed_tokens(bos)
    bos_atts = attention_mask[:, :1]
    inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
    attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

    # assemble the final targets
    targets = torch.full(
        [inputs_embeds.shape[0], inputs_embeds.shape[1]], -100,
        dtype=torch.long, device=self.device)
    for i, target in enumerate(part_targets):
        start = input_lens[i] + 1  # plus 1 for bos
        targets[i, start:start + len(target)] = target
    logging.debug("targets shape %s", targets.shape)

    with self.maybe_autocast():
        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
            reduction=reduction
        )
    return {"loss": outputs.loss}
@torch.no_grad()
def generate(
    self,
    images,
    texts,
    use_nucleus_sampling=False,
    num_beams=1,
    max_new_tokens=20,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1,
    temperature=1,
    do_sample=False,
    stop_words_ids=[2],
    lengths=None,
    return_video_temporal_features=False,
    img_embeds=None,
):
    """Generate one answer per prompt from images (or precomputed image features).

    Args:
        images: image tensor, encoded with ``encode_img`` when ``img_embeds`` is None.
        texts: prompts containing ``<ImageHere>`` placeholders.
        num_beams, max_new_tokens, do_sample, temperature, repetition_penalty:
            forwarded to the language model's ``generate``.
        use_nucleus_sampling, min_length, top_p, length_penalty, stop_words_ids:
            accepted for API compatibility but not used by this method
            (the stopping criteria are not passed to ``generate``).
        lengths: optional number of images per sample; embeddings are then
            regrouped per sample and only the first ``lengths[i]`` are used.
        return_video_temporal_features: also return the mean over the token
            axis of the last hidden state for the padded inputs.
        img_embeds: optional un-projected image features; they are projected
            with ``llama_proj`` instead of running the vision encoder.

    Returns:
        A list of decoded answers, or ``(answers, video_temporal_features)``.
    """
    if img_embeds is None:
        img_embeds, _ = self.encode_img(images.to(self.device))
    else:
        # Use image features supplied by the caller: flatten the leading dims and project.
        img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
        img_embeds = img_embeds.to(self.device)
        img_embeds = self.llama_proj(img_embeds)
    logging.debug("img_embeds shape %s", img_embeds.shape)

    if lengths is not None:
        image_lists = []
        img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
        for idx, img_embed in enumerate(img_embeds):
            image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
    else:
        image_lists = [[image_emb[None]] for image_emb in img_embeds]
    assert len(texts) == len(image_lists)

    batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

    # Left-pad every sample to the longest one so generation continues from the right edge.
    batch_size = len(batch_embs)
    max_len = max([emb.shape[1] for emb in batch_embs])
    first = batch_embs[0]
    embs = torch.zeros([batch_size, max_len, first.shape[2]], dtype=first.dtype, device=first.device)
    attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=first.device)
    for i, emb in enumerate(batch_embs):
        emb_len = emb.shape[1]
        embs[i, -emb_len:] = emb[0]
        attn_mask[i, -emb_len:] = 1

    # Keep only the most recent positions when the input exceeds the context budget.
    context_window = 3700 if self.model_type == "Llama" else 7500
    if embs.shape[1] > context_window:
        embs = embs[:, -context_window:]
        attn_mask = attn_mask[:, -context_window:]
    logging.debug("inputs_embeds shape %s", embs.shape)
    logging.debug("attention_mask shape %s", attn_mask.shape)

    with self.maybe_autocast():
        if return_video_temporal_features:
            last_hidden_state = self.llama_model(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                output_hidden_states=True,
            ).hidden_states[-1]
            video_temporal_features = last_hidden_state.mean(dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=embs,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )

    answers = []
    for output_token in outputs:
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
        output_texts = output_texts.replace("<s>", "")
        output_texts = output_texts.split(r'[/INST]')[-1].strip()
        answers.append(output_texts)

    if return_video_temporal_features:
        return answers, video_temporal_features
    return answers
@torch.no_grad()
def generate_text_only(
    self,
    images,
    seg_tokens,
    use_nucleus_sampling=False,
    num_beams=1,
    max_new_tokens=20,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1,
    temperature=1,
    do_sample=False,
    stop_words_ids=[2],
    lengths=None,
    return_video_temporal_features=False,
    img_embeds=None,
):
    """Generate one answer per pre-tokenized text prompt, without any image input.

    Args:
        seg_tokens: iterable of token-id tensors, one per sample; each is
            embedded with ``embed_tokens``.
        num_beams, max_new_tokens, do_sample, temperature, repetition_penalty:
            forwarded to the language model's ``generate``.
        images, use_nucleus_sampling, min_length, top_p, length_penalty,
            stop_words_ids, lengths, return_video_temporal_features, img_embeds:
            accepted for signature parity with ``generate`` but not used here.

    Returns:
        A list of decoded answers.
    """
    batch_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

    # Left-pad every sample to the longest one so generation continues from the right edge.
    batch_size = len(batch_embs)
    max_len = max([emb.shape[1] for emb in batch_embs])
    first = batch_embs[0]
    embs = torch.zeros([batch_size, max_len, first.shape[2]], dtype=first.dtype, device=first.device)
    attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=first.device)
    for i, emb in enumerate(batch_embs):
        emb_len = emb.shape[1]
        embs[i, -emb_len:] = emb[0]
        attn_mask[i, -emb_len:] = 1

    logging.debug("inputs_embeds shape %s", embs.shape)
    logging.debug("attention_mask shape %s", attn_mask.shape)
    with self.maybe_autocast():
        outputs = self.llama_model.generate(
            inputs_embeds=embs,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )

    answers = []
    for output_token in outputs:
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
        output_texts = output_texts.replace("<s>", "")
        output_texts = output_texts.split(r'[/INST]')[-1].strip()
        answers.append(output_texts)
    return answers
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
    """Rank answer candidates by their per-sample loss, lowest loss first.

    Args:
        images, texts: batch inputs reused for every candidate.
        answers: one entry per candidate slot, each holding that slot's
            answer for every sample in the batch.
        num_cand: optional per-sample number of valid candidates; slots at or
            beyond it get a loss of 9999 so that they sort last.

    Returns:
        A nested list: for each sample, candidate indices sorted by loss.
    """
    per_candidate = []
    for answer in answers:
        batch = {
            'image': images,
            'instruction_input': texts,
            'answer': answer
        }
        unreduced = self.forward(batch, reduction='none')['loss']
        per_candidate.append(unreduced.reshape(-1, 1))
        torch.cuda.empty_cache()

    all_losses = torch.cat(per_candidate, dim=-1)
    if num_cand is not None:
        for row in range(all_losses.shape[0]):
            all_losses[row, num_cand[row]:] = 9999

    return torch.argsort(all_losses, dim=-1).tolist()
def predict_answers(
    self,
    samples,
    num_beams=5,
    inference_method="generate",
    max_len=10,
    min_len=1,
    num_ans_candidates=128,
    answer_list=None,
    prompt="",
    length_penalty=0,
    **kwargs
):
    """Answer open-ended VQA samples by free-form generation.

    Args:
        samples: dict with ``image`` and ``instruction_input``; if the key
            ``apply_lemmatizer`` is present and truthy the answers are passed
            through ``self._lemmatize``.
        num_beams, max_len, min_len, length_penalty: forwarded to ``generate``
            (as ``num_beams``, ``max_new_tokens``, ``min_length``, ``length_penalty``).
        inference_method, num_ans_candidates, answer_list, prompt, kwargs:
            accepted but not used.

    Returns:
        The list of generated answers.
    """
    # Follow the model's device instead of hard-coding CUDA; generate() moves
    # the images to self.device as well.
    images = samples["image"].to(self.device)
    texts = samples["instruction_input"]

    output_text = self.generate(
        images=images,
        texts=texts,
        num_beams=num_beams,
        max_new_tokens=max_len,
        min_length=min_len,
        length_penalty=length_penalty
    )

    if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
        output_text = self._lemmatize(output_text)

    return output_text
def predict_class(
    self,
    samples,
    num_beams=5,
    inference_method="generate",
    max_len=10,
    min_len=1,
    num_ans_candidates=5,
    answer_list=None,
    prompt="",
    length_penalty=0,
    **kwargs
):
    """Answer multi-choice VQA samples by picking the lowest-loss candidate.

    Args:
        samples: dict with ``image``, ``instruction_input``, ``choices``
            (indexed as ``choices[candidate][sample]``) and ``num_choices``.
        All other parameters are accepted but not used.

    Returns:
        A list with the top-ranked choice for every sample.
    """
    # Follow the model's device instead of hard-coding CUDA.
    image = samples["image"].to(self.device)
    instruction = samples['instruction_input']
    answers = samples["choices"]
    num_cand = samples["num_choices"]

    ranks = self.multi_select(image, instruction, answers, num_cand)

    # rank[0] is the index of the best candidate for sample i.
    return [answers[rank[0]][i] for i, rank in enumerate(ranks)]
def embed_tokens(self, token_ids):
    """Embed ``token_ids`` with the language model's input embedding layer.

    Two attribute layouts of ``self.llama_model`` are supported: the nested
    ``base_model.model.model.embed_tokens`` path (presumably a PEFT-wrapped
    model -- verify) and the plain ``model.embed_tokens`` path.

    Only the attribute lookup is guarded: an ``AttributeError`` raised
    inside the embedding layer itself propagates instead of silently
    triggering the fallback path.
    """
    try:
        embedding_layer = self.llama_model.base_model.model.model.embed_tokens
    except AttributeError:
        embedding_layer = self.llama_model.model.embed_tokens
    return embedding_layer(token_ids)
@classmethod
def from_config(cls, cfg):
    """Build a model from a config mapping and optionally load a checkpoint.

    Args:
        cfg: mapping-like object exposing ``get(key, default)``.  Missing
            keys fall back to the defaults below; ``image_size`` and
            ``llama_model`` have no default and are forwarded as ``None``
            when absent.  A non-empty ``ckpt`` entry names a checkpoint file
            whose ``'model'`` state dict is loaded non-strictly.

    Returns:
        The constructed model instance.
    """
    model = cls(
        vit_model=cfg.get("vit_model", "eva_clip_g"),
        img_size=cfg.get("image_size"),
        drop_path_rate=cfg.get("drop_path_rate", 0),
        use_grad_checkpoint=cfg.get("use_grad_checkpoint", False),
        vit_precision=cfg.get("vit_precision", "fp16"),
        freeze_vit=cfg.get("freeze_vit", True),
        llama_model=cfg.get("llama_model"),
        prompt_path=cfg.get("prompt_path", ""),
        prompt_template=cfg.get("prompt_template", ""),
        max_txt_len=cfg.get("max_txt_len", 300),
        low_resource=cfg.get("low_resource", False),
        end_sym=cfg.get("end_sym", '\n'),
        lora_r=cfg.get("lora_r", 64),
        lora_alpha=cfg.get("lora_alpha", 16),
        chat_template=cfg.get("chat_template", False),
        system_prompt=cfg.get("system_prompt", False),
        token_pooling=cfg.get("token_pooling", True),
        use_grad_checkpoint_llm=cfg.get("use_grad_checkpoint_llm", False),
        max_context_len=cfg.get("max_context_len", 3800),
        remove_template=cfg.get("remove_template", False),
    )
    ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
    if ckpt_path:
        print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # strict=False: keys missing from / unexpected in the checkpoint are tolerated.
        model.load_state_dict(ckpt['model'], strict=False)
    return model
def assign_imgs(batched_instruct_list, batched_img_embeds):
    '''Assign image embeddings to the segments of interleaved instructions.

    The interleaved data has been split into segments; every segment gets
    the image embeddings that belong to its ``<ImageHere>`` placeholders,
    consumed in order of appearance.

    Args:
        batched_instruct_list: per sample, a list of instruction segments.
        batched_img_embeds: per-sample image embeddings, indexed along the
            first axis by sample and along the second by image.  A 3-D
            tensor is treated as one image per sample.  ``None`` means the
            batch carries no images.

    Returns:
        Per sample, a list with one entry per segment: a slice of the
        sample's embeddings with a leading axis of size 1 for segments that
        contain placeholders, ``None`` otherwise (and for every segment
        when ``batched_img_embeds`` is ``None``).
    '''
    if batched_img_embeds is None:
        # Text-only batch: nothing to assign.
        return [[None for _ in instruct_list] for instruct_list in batched_instruct_list]
    if len(batched_img_embeds.shape) == 3:
        batched_img_embeds = batched_img_embeds[:, None]
    batched_assigned = []
    for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
        img_idx = 0
        assigned_img = []
        for instruct in instruct_list:
            n_img = instruct.count('<ImageHere>')
            if n_img > 0:  # this segment includes images
                assigned_img.append(img_embeds[None, img_idx:img_idx + n_img])
                img_idx += n_img
            else:  # this segment doesn't include images
                assigned_img.append(None)
        batched_assigned.append(assigned_img)
    return batched_assigned

709
models/backbones/mini_gpt4v.py Executable file
View file

@ -0,0 +1,709 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training
)
import time
import numpy as np
from minigpt4.models import policies
@registry.register_model("mini_gpt4v")
class MiniGPT4v(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.

    Registered in the model registry under the name ``"mini_gpt4v"``.
    """
    # Alias -> config file path.
    # NOTE(review): presumably consumed by a base-class ``from_pretrained``
    # style lookup -- confirm against Blip2Base.
    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }
def __init__(
    self,
    vit_model="eva_clip_g",
    img_size=224,
    drop_path_rate=0,
    use_grad_checkpoint=False,
    vit_precision="fp16",
    freeze_vit=True,
    llama_model="",
    prompt_path="",
    prompt_template="",
    max_txt_len=32,
    low_resource=False, # use 8 bit and put vit in cpu
    end_sym='\n',
    lora_r = 8,
    lora_target_modules = ["q_proj","v_proj"],
    lora_alpha=16,
    # lora_r = 16,
    # lora_target_modules = ["q_proj","v_proj","v_proj"],
    lora_dropout= 0.05,
    ckpt_path = "",
    system_prompt= False,
    chat_template=False,
    token_pooling=True,
    use_grad_checkpoint_llm=False,
    max_context_len=3800,
    remove_template = False,
):
    """Build the vision encoder, the 4-bit quantized LLaMA and the projection layer.

    Args:
        vit_model, drop_path_rate, use_grad_checkpoint, vit_precision:
            forwarded to ``self.init_vision_encoder``.
        llama_model: name or path passed to the tokenizer and model
            ``from_pretrained`` calls; a ``CodeLlamaTokenizer`` is used when
            the string contains ``'CodeLlama'``.
        prompt_path: optional text file with one prompt per line; only lines
            containing ``<ImageHere>`` are kept.
        prompt_template: format string applied to every kept prompt.
        max_txt_len, end_sym, system_prompt, chat_template, token_pooling,
        use_grad_checkpoint_llm, max_context_len, remove_template,
        low_resource: stored on the instance unchanged.

    NOTE(review): ``img_size``, ``freeze_vit``, ``lora_r``,
    ``lora_target_modules``, ``lora_alpha``, ``lora_dropout`` and
    ``ckpt_path`` are accepted but never read in this constructor: the image
    size is hard-coded to 224 and the vision encoder is always frozen.
    ``lora_target_modules`` also has a mutable list default; harmless while
    it stays unused.
    """
    super().__init__()
    self.tokenizer = self.init_tokenizer()
    self.low_resource = low_resource
    self.token_pooling = token_pooling
    self.remove_template = remove_template
    print("token pooling", self.token_pooling)
    self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
    self.max_context_len = max_context_len
    self.chat_template = chat_template
    # print('Loading VIT')
    # self.visual_encoder, self.ln_vision = self.init_vision_encoder(
    # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
    # )
    print("vit precision", vit_precision)
    # The vision encoder is always built for 224x224 inputs; larger images
    # are tiled into 224x224 crops in ``encode_img``.
    self.visual_encoder, self.ln_vision = self.init_vision_encoder(
        vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
    )
    # Freeze the vision encoder and its layer norm: no gradients, eval mode,
    # and ``train`` replaced so later ``.train()`` calls use ``disabled_train``.
    for name, param in self.visual_encoder.named_parameters():
        param.requires_grad = False
    self.visual_encoder = self.visual_encoder.eval()
    self.visual_encoder.train = disabled_train
    for name, param in self.ln_vision.named_parameters():
        param.requires_grad = False
    self.ln_vision = self.ln_vision.eval()
    self.ln_vision.train = disabled_train
    logging.info("freeze vision encoder")
    print("freeze the vision encoder")
    print('Loading VIT Done')
    # print("visual encoder shape", self.visual_encoder.pos_embed.shape)
    # assert False
    print('Loading LLAMA')
    self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    # Both branches set the same pad token; only the tokenizer class differs.
    if 'CodeLlama' in llama_model:
        self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
        self.llama_tokenizer.pad_token = "$$"
    else:
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
        self.llama_tokenizer.pad_token = "$$"
    self.system_prompt = system_prompt
    # 4-bit NF4 quantization with double quantization and bfloat16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    # device_map={"": 0} places the whole model on device index 0.
    self.llama_model = LlamaForCausalLM.from_pretrained(
        llama_model,
        quantization_config=bnb_config,
        device_map={"": 0}
    )
    # self.llama_model.gradient_checkpointing_enable()
    self.llama_model = prepare_model_for_kbit_training(self.llama_model)
    # self.llama_model.print_trainable_parameters()
    print('Loading LLAMA Done')
    # ``merge_n**2`` neighbouring visual tokens are concatenated along the
    # feature axis before projection (see ``encode_img``).
    # NOTE(review): 1408 is presumably the vision encoder's hidden size -- confirm.
    self.merge_n = 3
    self.llama_proj = nn.Linear(
        1408 * self.merge_n**2, self.llama_model.config.hidden_size
    )
    self.max_txt_len = max_txt_len
    self.end_sym = end_sym
    if prompt_path:
        with open(prompt_path, 'r') as f:
            raw_prompts = f.read().splitlines()
        # Keep only prompts that contain an image placeholder.
        filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
        self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
        print('Load {} training prompts'.format(len(self.prompt_list)))
        print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
    else:
        self.prompt_list = []
def encode_img(self, image):
    """Encode images into LLaMA-space embeddings.

    The image is cut into 224x224 tiles, every tile is encoded by the frozen
    vision encoder, the tile tokens are re-assembled into one sequence per
    image, ``merge_n**2`` consecutive tokens are concatenated along the
    feature axis and the result is projected with ``self.llama_proj``.

    Args:
        image: tensor whose last three axes are (channels, w, h); any extra
            leading axes are flattened into the batch axis.  Both spatial
            sizes must be multiples of 224.

    Returns:
        ``(inputs_llama, atts_llama)``: the projected embeddings and an
        all-ones ``long`` attention mask covering every projected token.
    """
    device = image.device
    if len(image.shape) > 4:
        # Collapse extra leading axes (e.g. frames) into the batch axis.
        image = image.reshape(-1, *image.shape[-3:])
    bs, ch, w, h = image.shape
    assert w % 224 == 0
    bw = w // 224
    assert h % 224 == 0
    bh = h // 224
    image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224
    image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)
    # NOTE(review): indentation was lost in this listing; the autocast scope
    # is assumed to extend to the line before ``return`` -- confirm.
    with self.maybe_autocast():
        image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
        # Drop the first token of every tile (presumably the CLS token --
        # verify); the reshape then expects 16*16 tokens per tile.
        image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
        image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs
        image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])
        bs, pn, hs = image_embeds.shape
        # Merge merge_n**2 consecutive tokens into one wider token.
        # NOTE(review): ``view`` only succeeds when pn is divisible by
        # merge_n**2; with merge_n == 3 a single 224x224 tile (256 tokens)
        # is not -- confirm the expected input sizes.
        image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2))
        inputs_llama = self.llama_proj(image_embeds)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
    return inputs_llama, atts_llama
def get_context_emb(self, prompt, img_list):
    """Interleave the embeddings of a prompt's text segments with image embeddings.

    ``prompt`` is split at every ``<ImageHere>`` placeholder; the text
    segments are tokenized (special tokens only on the first segment) and
    embedded, and image ``k`` of ``img_list`` is inserted after text
    segment ``k``.

    Returns:
        The concatenation of all pieces along dimension 1.
    """
    img_device = img_list[0].device
    prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    seg_embs = []
    for position, segment in enumerate(prompt_segs):
        # only add bos to the first seg
        encoded = self.llama_tokenizer(
            segment, return_tensors="pt", add_special_tokens=position == 0
        ).to(img_device)
        seg_embs.append(self.embed_tokens(encoded.input_ids))

    pieces = []
    for text_emb, img_emb in zip(seg_embs[:-1], img_list):
        pieces.append(text_emb)
        pieces.append(img_emb)
    pieces.append(seg_embs[-1])
    return torch.cat(pieces, dim=1)
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
    """Build right-padded prompt embeddings, inserting images at ``<ImageHere>``.

    Three cases are handled:

    * no prompts: ``(img_embeds, atts_img)`` is returned unchanged;
    * prompts but no image embeddings: the prompts are tokenized with right
      padding and embedded;
    * prompts and image embeddings: every prompt is split at
      ``<ImageHere>``, and the ``k``-th placeholder is replaced by the
      ``k``-th group of ``pn`` image tokens of that sample.  A prompt
      without any placeholder yields its text embedding only.

    Args:
        img_embeds: per-sample image embeddings or ``None``.
        atts_img: image attention mask; only used in the "no prompts" case.
        prompts: list of prompt strings (or ``None``).
        lengths: optional per-sample number of images; when given, each
            sample's embeddings are flattened to a token sequence and cut
            to ``lengths[i] * pn`` tokens.

    Returns:
        ``(embeddings, attention_mask)``.  In the multi-modal case both are
        padded with the pad-token embedding / zeros and truncated to
        ``self.max_context_len``.
    """
    if prompts is None or len(prompts) == 0:
        # prompts is not provided, just return the original image embedding
        return img_embeds, atts_img
    if img_embeds is None:
        # prompt is provided but there is no image embedding. return the prompt embedding in right padding
        self.llama_tokenizer.padding_side = "right"
        prompt_tokens = self.llama_tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
            add_special_tokens=False
        ).to(self.device)
        prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
        atts_prompt = prompt_tokens.attention_mask
        return prompt_embeds, atts_prompt

    # return the multi-modal embedding in right padding
    device = img_embeds.device
    emb_lists = []
    for sample_idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
        pn = each_img_embed.shape[-2]  # tokens per image
        if lengths is not None:
            each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
            each_img_embed = each_img_embed[:lengths[sample_idx] * pn]
        p_segs = each_prompt.split('<ImageHere>')
        pieces = []
        for seg_idx, seg in enumerate(p_segs[:-1]):
            p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(device)
            pieces.append(self.embed_tokens(p_tokens.input_ids))
            pieces.append(each_img_embed[None][:, seg_idx * pn:(seg_idx + 1) * pn])
        p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(device)
        pieces.append(self.embed_tokens(p_tokens.input_ids))
        # A single concatenation over all pieces also works when the prompt
        # has no placeholder (``pieces`` then holds the text embedding only).
        emb_lists.append(torch.cat(pieces, dim=1))

    emb_lens = [emb.shape[1] for emb in emb_lists]
    pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=device))
    max_length = min(max(emb_lens), self.max_context_len)
    wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
    wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=device)
    for i, emb in enumerate(emb_lists):
        length = min(emb_lens[i], self.max_context_len)
        wrapped_embs[i, :length] = emb[:, :length]
        wrapped_atts[i, :length] = 1
    return wrapped_embs, wrapped_atts
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
    """
    Splice each sample's output embedding in right after its valid input tokens.

    Both the input and the output batch are expected to be right padded.
    For every sample the result is ``valid input | output | input padding``;
    the attention masks are spliced the same way.

    Returns:
        ``(cat_embs, cat_atts, input_lens)`` where ``input_lens`` is a list
        holding, per sample, the sum of its input attention mask.
    """
    input_lens, cat_embs, cat_atts = [], [], []
    for row in range(input_embs.size(0)):
        row_emb, row_att = input_embs[row], input_atts[row]
        n_valid = row_att.sum()
        input_lens.append(n_valid)
        cat_embs.append(
            torch.cat([row_emb[:n_valid], output_embs[row], row_emb[n_valid:]])
        )
        cat_atts.append(
            torch.cat([row_att[:n_valid], output_atts[row], row_att[n_valid:]])
        )
    return torch.stack(cat_embs), torch.stack(cat_atts), input_lens
def get_conv_emb(self, conv_q, conv_a, conv_img):
    """Concatenate a conversation so that only the answers are regressed.

    For every sample the question and answer turns are embedded
    alternately (question 0, answer 0, question 1, ...).  Question
    positions get the target ``-100``; answer positions get the answer
    token ids.

    Args:
        conv_q: per sample, a list of question strings.
        conv_a: per sample, a list of answer strings.
        conv_img: per sample, a list with the image embeddings (or ``None``)
            assigned to each question.

    Returns:
        ``(regress_embeds, regress_attn, targets)``, right padded and
        truncated to ``min(longest conversation, self.max_txt_len)``.
    """
    regress_embs_list = []
    targets_list = []
    batch_size = len(conv_q)
    for batch_idx in range(batch_size):
        questions, answers = conv_q[batch_idx], conv_a[batch_idx]
        assigned_imgs = conv_img[batch_idx]
        # Each question is wrapped on its own (batch of one prompt); image-less
        # questions take the text-only path of ``prompt_wrap``.
        questions = [self.prompt_wrap(
            img_embeds=img,
            atts_img=None,
            prompts=[q],
            lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
        q_embs = [emb for emb, _ in questions]
        answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
        cur_emb = []
        cur_target = []
        for i in range(len(questions)):
            cur_emb.append(q_embs[i])
            # -100 marks question tokens as ignored by the loss.
            cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
            cur_emb.append(self.embed_tokens(answers[i].input_ids))
            cur_target.append(answers[i].input_ids)
        cur_emb = torch.cat(cur_emb, dim=1)
        cur_target = torch.cat(cur_target, dim=1)
        regress_embs_list.append(cur_emb)
        targets_list.append(cur_target)
    max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
    # NOTE(review): ``cur_emb`` is the leftover of the last loop iteration, so
    # an empty batch fails here; the buffer uses torch's default dtype.
    regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
    regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
    targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
    for batch_idx in range(batch_size):
        cur_len = regress_embs_list[batch_idx].shape[1]
        # Slicing clamps ``:cur_len`` to ``max_len`` rows, so over-long
        # conversations are truncated on both sides of the assignment.
        regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
        regress_attn[batch_idx, :cur_len] = 1
        targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
    return regress_embeds, regress_attn, targets
def preparing_embedding(self, samples):
    """Build the conditioning and regression embeddings for one batch.

    Two input layouts are supported:

    * conversation samples (``'conv_q'`` present): questions and answers
      are split on ``samples['connect_sym'][0]`` and embedded with
      ``get_conv_emb``; the conditioning part is empty;
    * instruction samples: ``samples['instruction_input']`` (optionally
      stripped of task tags and wrapped in ``[INST] ... [/INST]``) is the
      condition, and ``samples['answer']`` plus ``self.end_sym`` is the
      regression target.

    Returns:
        ``(cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets)``.
    """
    def remove_special_tokens(data):
        """Strip the task tags (" [caption]", " [vqa]", ...) from every instruction."""
        # if "instruction_input" in data:
        data = [instruct.replace(" [caption]","") for instruct in data]
        data = [instruct.replace(" [vqa]","") for instruct in data]
        data = [instruct.replace(" [grounding]","") for instruct in data]
        data = [instruct.replace(" [identify]","") for instruct in data]
        data = [instruct.replace(" [refer]","") for instruct in data]
        return data
    ### prepare input tokens
    if 'image' in samples:
        img_embeds, img_atts = self.encode_img(samples["image"])
    else:
        img_embeds = img_atts = None
    if 'conv_q' in samples:
        # handling conversation datasets
        conv_q, conv_a = samples['conv_q'], samples['conv_a']
        connect_sym = samples['connect_sym'][0]
        conv_q = [q.split(connect_sym)for q in conv_q]
        conv_a = [a.split(connect_sym) for a in conv_a]
        # NOTE(review): ``img_embeds`` is None here when the batch has no 'image' key.
        conv_img = assign_imgs(conv_q, img_embeds)
        if self.chat_template:
            conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
        regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
        # Zero-length slices: the whole conversation lives in the regression part.
        cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
    else:
        instruction = samples["instruction_input"] if "instruction_input" in samples else None
        # print("instruction before", instruction)
        # NOTE(review): both template branches below iterate ``instruction``
        # and fail if it is None.
        if self.remove_template:
            instruction = remove_special_tokens(instruction)
        # print("instruction after", instruction)
        if self.chat_template:
            instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
        if 'length' in samples:
            # the input is a sequence of images (like videos)
            bsz, pn, hs = img_embeds.shape
            img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
        else:
            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
        ### prepare target tokens
        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["answer"]]
        regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(self.device)
        regress_token_ids = regress_tokens.input_ids
        regress_atts = regress_tokens.attention_mask
        # Padding positions must not contribute to the loss.
        part_targets = regress_token_ids.masked_fill(
            regress_token_ids == self.llama_tokenizer.pad_token_id, -100
        )
        regress_embeds = self.embed_tokens(regress_token_ids)
    return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
def forward(self, samples, reduction="mean"):
    """Compute the language-modelling loss of the answers given the condition.

    The conditioning and regression embeddings from
    ``preparing_embedding`` are spliced together, a BOS embedding is
    prepended, and labels are built that are ``-100`` everywhere except at
    the regression positions.

    Args:
        samples: batch dictionary consumed by ``preparing_embedding``.
        reduction: forwarded to ``self.llama_model``.

    Returns:
        ``{"loss": loss}`` with the loss reported by the language model.
    """
    # embeddings to condition on and embeddings to regress
    (cond_embeds, cond_atts,
     regress_embeds, regress_atts, part_targets) = self.preparing_embedding(samples)

    # splice condition and regression parts together
    inputs_embeds, attention_mask, input_lens = self.concat_emb_input_output(
        cond_embeds, cond_atts, regress_embeds, regress_atts
    )

    # prepend the bos token; its mask entry mirrors the first existing one
    bos_ids = torch.full_like(part_targets[:, :1], self.llama_tokenizer.bos_token_id)
    inputs_embeds = torch.cat([self.embed_tokens(bos_ids), inputs_embeds], dim=1)
    attention_mask = torch.cat([attention_mask[:, :1], attention_mask], dim=1)

    # labels: ignored (-100) everywhere except the regression span
    batch, seq_len = inputs_embeds.shape[0], inputs_embeds.shape[1]
    targets = torch.full([batch, seq_len], -100, dtype=torch.long).to(self.device)
    for row, target in enumerate(part_targets):
        start = input_lens[row] + 1  # plus 1 for bos
        targets[row, start:start + len(target)] = target

    with self.maybe_autocast():
        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
            reduction=reduction
        )
    return {"loss": outputs.loss}
@torch.no_grad()
def generate(
    self,
    images,
    texts,
    use_nucleus_sampling=False,
    num_beams=1,
    max_new_tokens=20,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1,
    length_penalty=1,
    temperature=1,
    do_sample=False,
    stop_words_ids=[2],
    lengths=None,
):
    '''
    Generate text for a batch of images and prompts (test-time use).

    Every prompt must contain one ``<ImageHere>`` placeholder per image of
    its sample.  The per-sample context embeddings are left padded to a
    common length and passed to ``self.llama_model.generate``.

    Args:
        images: image batch, moved to ``self.device`` before encoding.
        texts: one prompt string per sample.
        num_beams, max_new_tokens, do_sample: forwarded to the language
            model's ``generate``.
        lengths: optional per-sample number of images; when given, the
            encoded images are regrouped per sample.

    NOTE(review): ``use_nucleus_sampling``, ``min_length``, ``top_p``,
    ``repetition_penalty``, ``length_penalty`` and ``temperature`` are
    accepted but never forwarded, and the stopping criteria built from
    ``stop_words_ids`` are not passed either (the keyword is commented out).

    Returns:
        A list with one decoded answer string per sample.
    '''
    # Built but currently unused -- see the commented-out keyword below.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
        stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
    img_embeds, atts_img = self.encode_img(images.to(self.device))
    if lengths is not None:
        image_lists = []
        img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
        for idx, img_embed in enumerate(img_embeds):
            # Keep only the first ``lengths[idx]`` images of the sample.
            image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
    else:
        image_lists = [[image_emb[None]] for image_emb in img_embeds]
    assert len(texts) == len(image_lists)
    batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
    batch_size = len(batch_embs)
    max_len = max([emb.shape[1] for emb in batch_embs])
    emb_dim = batch_embs[0].shape[2]
    dtype = batch_embs[0].dtype
    device = batch_embs[0].device
    embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
    attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
    for i, emb in enumerate(batch_embs):
        emb_len = emb.shape[1]
        # Left padding: the real tokens occupy the END of each row.
        embs[i, -emb_len:] = emb[0]
        attn_mask[i, -emb_len:] = 1
    with self.maybe_autocast():
        outputs = self.llama_model.generate(
            inputs_embeds=embs,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            # stopping_criteria=stopping_criteria,
        )
    answers = []
    for output_token in outputs:
        # Drop a leading token with id 0 (presumably <unk> -- verify).
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
        output_texts = output_texts.replace("<s>", "")
        output_texts = output_texts.split(r'[/INST]')[-1].strip()
        answers.append(output_texts)
    return answers
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
    """Rank answer candidates for every sample by language-model loss.

    Each element of ``answers`` is one batch of candidates (one candidate
    per sample).  Every batch is scored with ``self.forward(...,
    reduction='none')`` and, per sample, the candidates are ordered by
    ascending loss.

    Args:
        images: image batch, forwarded unchanged under the ``'image'`` key.
        texts: instructions, forwarded under ``'instruction_input'``.
        answers: iterable of candidate batches.
        num_cand: optional per-sample count of valid candidates; candidate
            slots at index ``>= num_cand[i]`` are masked so that they rank
            after every valid candidate.

    Returns:
        A list (one entry per sample) of candidate indices sorted from the
        lowest loss to the highest.
    """
    per_candidate_losses = []
    for answer in answers:
        choice_samples = {
            'image': images,
            'instruction_input': texts,
            'answer': answer
        }
        loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
        per_candidate_losses.append(loss)
        # Release cached blocks between candidates to keep peak memory low.
        torch.cuda.empty_cache()
    all_losses = torch.cat(per_candidate_losses, dim=-1)
    if num_cand is not None:
        for i in range(all_losses.shape[0]):
            # +inf (instead of a finite magic number) guarantees that padded
            # slots can never outrank a real candidate.
            all_losses[i, num_cand[i]:] = float("inf")
    output_class_ranks = torch.argsort(all_losses, dim=-1)
    return output_class_ranks.tolist()
def predict_answers(
    self,
    samples,
    num_beams=5,
    inference_method="generate",
    max_len=10,
    min_len=1,
    num_ans_candidates=128,
    answer_list=None,
    prompt="",
    length_penalty=0,
    **kwargs
):
    '''
    Answer open-ended VQA questions by free-form generation.

    Reads ``samples["image"]`` and ``samples["instruction_input"]`` and
    delegates to ``self.generate``.  The images are moved to the model's
    own device (``self.device``) rather than unconditionally to CUDA, so
    the method also works when the model does not live on the default GPU.

    When ``samples["apply_lemmatizer"]`` is present and truthy, the output
    is post-processed with ``self._lemmatize``.
    NOTE(review): ``_lemmatize`` is not defined in this class -- confirm
    the base class provides it.

    ``inference_method``, ``num_ans_candidates``, ``answer_list``,
    ``prompt`` and ``**kwargs`` are accepted but unused.

    Returns:
        The list of generated answers (lemmatized if requested).
    '''
    images = samples["image"].to(self.device)
    texts = samples["instruction_input"]
    output_text = self.generate(
        images=images,
        texts=texts,
        num_beams=num_beams,
        max_new_tokens=max_len,
        min_length=min_len,
        length_penalty=length_penalty
    )
    if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
        output_text = self._lemmatize(output_text)
    return output_text
def predict_class(
    self,
    samples,
    num_beams=5,
    inference_method="generate",
    max_len=10,
    min_len=1,
    num_ans_candidates=5,
    answer_list=None,
    prompt="",
    length_penalty=0,
    **kwargs
):
    '''
    Pick the best choice for every sample of a multi-choice VQA batch.

    ``samples["choices"][k][i]`` is candidate ``k`` of sample ``i`` and
    ``samples["num_choices"]`` the number of valid candidates per sample.
    Candidates are ranked with ``self.multi_select`` and the top-ranked one
    is returned per sample.  The images are moved to the model's own device
    (``self.device``) rather than unconditionally to CUDA.

    All keyword arguments other than ``samples`` are accepted but unused.

    Returns:
        A list with the selected choice of every sample.
    '''
    image = samples["image"].to(self.device)
    instruction = samples['instruction_input']
    answers = samples["choices"]
    num_cand = samples["num_choices"]
    ranks = self.multi_select(image, instruction, answers, num_cand)
    # rank[0] is the index of the lowest-loss candidate for sample i.
    return [answers[rank[0]][i] for i, rank in enumerate(ranks)]
def embed_tokens(self, token_ids):
    """Embed ``token_ids`` with the language model's input embedding layer.

    Two attribute layouts of ``self.llama_model`` are supported: the nested
    ``base_model.model.model.embed_tokens`` path (presumably a PEFT-wrapped
    model -- verify) and the plain ``model.embed_tokens`` path.

    Only the attribute lookup is guarded: an ``AttributeError`` raised
    inside the embedding layer itself propagates instead of silently
    triggering the fallback path.
    """
    try:
        embedding_layer = self.llama_model.base_model.model.model.embed_tokens
    except AttributeError:
        embedding_layer = self.llama_model.model.embed_tokens
    return embedding_layer(token_ids)
@classmethod
def from_config(cls, cfg):
    """Build a model from a config mapping and optionally load a checkpoint.

    Args:
        cfg: mapping-like object exposing ``get(key, default)``.  Missing
            keys fall back to the defaults below; ``image_size`` and
            ``llama_model`` have no default and are forwarded as ``None``
            when absent.  A non-empty ``ckpt`` entry names a checkpoint file
            whose ``'model'`` state dict is loaded non-strictly.

    Returns:
        The constructed model instance.
    """
    model = cls(
        vit_model=cfg.get("vit_model", "eva_clip_g"),
        img_size=cfg.get("image_size"),
        drop_path_rate=cfg.get("drop_path_rate", 0),
        use_grad_checkpoint=cfg.get("use_grad_checkpoint", False),
        vit_precision=cfg.get("vit_precision", "fp16"),
        freeze_vit=cfg.get("freeze_vit", True),
        llama_model=cfg.get("llama_model"),
        prompt_path=cfg.get("prompt_path", ""),
        prompt_template=cfg.get("prompt_template", ""),
        max_txt_len=cfg.get("max_txt_len", 300),
        low_resource=cfg.get("low_resource", False),
        end_sym=cfg.get("end_sym", '\n'),
        lora_r=cfg.get("lora_r", 64),
        lora_alpha=cfg.get("lora_alpha", 16),
        chat_template=cfg.get("chat_template", False),
        system_prompt=cfg.get("system_prompt", False),
        token_pooling=cfg.get("token_pooling", True),
        use_grad_checkpoint_llm=cfg.get("use_grad_checkpoint_llm", False),
        max_context_len=cfg.get("max_context_len", 3800),
        remove_template=cfg.get("remove_template", False),
    )
    ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
    if ckpt_path:
        print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # strict=False: keys missing from / unexpected in the checkpoint are tolerated.
        model.load_state_dict(ckpt['model'], strict=False)
    return model
def assign_imgs(batched_instruct_list, batched_img_embeds):
    '''Assign image embeddings to the segments of interleaved instructions.

    The interleaved data has been split into segments; every segment gets
    the image embeddings that belong to its ``<ImageHere>`` placeholders,
    consumed in order of appearance.

    Args:
        batched_instruct_list: per sample, a list of instruction segments.
        batched_img_embeds: per-sample image embeddings, indexed along the
            first axis by sample and along the second by image.  A 3-D
            tensor is treated as one image per sample.  ``None`` means the
            batch carries no images.

    Returns:
        Per sample, a list with one entry per segment: a slice of the
        sample's embeddings with a leading axis of size 1 for segments that
        contain placeholders, ``None`` otherwise (and for every segment
        when ``batched_img_embeds`` is ``None``).
    '''
    if batched_img_embeds is None:
        # Text-only batch: nothing to assign.
        return [[None for _ in instruct_list] for instruct_list in batched_instruct_list]
    if len(batched_img_embeds.shape) == 3:
        batched_img_embeds = batched_img_embeds[:, None]
    batched_assigned = []
    for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
        img_idx = 0
        assigned_img = []
        for instruct in instruct_list:
            n_img = instruct.count('<ImageHere>')
            if n_img > 0:  # this segment includes images
                assigned_img.append(img_embeds[None, img_idx:img_idx + n_img])
                img_idx += n_img
            else:  # this segment doesn't include images
                assigned_img.append(None)
        batched_assigned.append(assigned_img)
    return batched_assigned

View file

@ -0,0 +1,25 @@
# Smoke-test script: load Mistral-7B-Instruct, print the shape of the token
# embeddings of one sentence, then generate a reply to a short chat.
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# Chat history in the role/content format expected by apply_chat_template.
messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]
p="Well, I'm quite partial to a good squeeze of fresh lemon juice."
encoded_input = tokenizer(p, return_tensors='pt')
# Direct lookup in the input embedding layer; this runs before the model is
# moved to `device`, so model and input ids are still where they were created.
embeds = model.model.embed_tokens(encoded_input.input_ids)
print(embeds.shape)
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(device)
model.to(device)
# Sampling is enabled, so the generated text differs between runs.
generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])

View file

@ -0,0 +1,112 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
class LlamaForCausalLM(LlamaForCausalLMOrig):
    # Variant of the Transformers LLaMA causal-LM head whose ``forward`` takes
    # an extra ``reduction`` argument for the cross-entropy loss.

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Passed to `CrossEntropyLoss`. With `"none"` the token losses are reshaped to
                `(batch_size, -1)` and averaged over the sequence axis, giving one loss per sample.
        Returns:
        Example:
        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM
        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # NOTE: the "Returns:" line above is rewritten by
        # ``replace_return_docstrings``; keep it on a line of its own.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Apply the LM head slice by slice along the vocabulary axis and
            # concatenate the partial logits.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                # One loss per sample: the mean runs over ALL shifted
                # positions, so positions labelled -100 (which contribute a
                # zero loss) still count in the denominator.
                loss = loss.view(logits.size(0), -1).mean(1)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

View file

@ -0,0 +1,112 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
class LlamaForCausalLM(LlamaForCausalLMOrig):
    """LLaMA causal LM whose ``forward`` accepts the loss ``reduction`` mode."""

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Forwarded to `CrossEntropyLoss`. With `"none"` the per-token losses are reshaped to
                `(batch_size, -1)` and averaged over dimension 1, so one value per sequence is returned.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM
        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to the model config for every output switch left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.config.pretraining_tp > 1:
            # Project with row-slices of the LM head and stitch the vocabulary axis back together.
            n_slices = self.config.pretraining_tp
            weight_slices = self.lm_head.weight.split(self.vocab_size // n_slices, dim=0)
            partial_logits = [F.linear(hidden_states, weight_slices[i]) for i in range(n_slices)]
            logits = torch.cat(partial_logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss(reduction=reduction)
            # Drop the last prediction and the first label so that position t predicts token t + 1,
            # then flatten both to one row per token.
            predictions = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            targets = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            targets = targets.to(predictions.device)
            loss = criterion(predictions, targets)
            if reduction == "none":
                # One value per sequence: plain mean over all shifted positions.
                loss = loss.view(logits.size(0), -1).mean(1)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

File diff suppressed because it is too large Load diff

287
models/backbones/moes.py Normal file
View file

@ -0,0 +1,287 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
class Mlp(nn.Module):
    """Two-layer feed-forward block: ``fc1 -> act -> drop -> fc2 -> drop``.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer; a falsy value means ``in_features``.
        out_features: Width of the output; a falsy value means ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability; the same dropout module is applied after both layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block to ``x`` along its last dimension."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: Channel width of the input and output.
        num_heads: Number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_scale: Score scale; a falsy value means ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(batch, tokens, channels)``.

        ``mask`` is converted to bool and positions where it is False are set to
        ``-inf`` before the softmax. It is applied as given, so it must broadcast
        against the ``(batch, heads, tokens, tokens)`` score tensor.

        Returns:
            ``(output, weights)``: the projected output with the shape of ``x`` and
            the attention weights (after dropout).
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # (3, batch, heads, tokens, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        queries, keys, values = fused[0], fused[1], fused[2]

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        if mask is not None:
            blocked = ~mask.bool()
            scores = scores.masked_fill(blocked, float("-inf"))

        weights = self.attn_drop(scores.softmax(dim=-1))
        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged)), weights
class MoELayer(nn.Module):
    """Pre-norm transformer block with shared attention and per-modality MLP experts.

    Every expert is a ``norm_<suffix>`` / ``mlp_<suffix>`` pair applied as a
    residual update. ``expert_type`` decides which experts are built:
    ``'modalities'`` builds the vis and caption experts (plus spatial/temporal
    when ``use_sep_spatial_temp_experts`` and history when ``has_hist``);
    ``'fusion'`` builds a single fusion expert.
    """

    # Expert names accepted in ``expert_permutation`` -> attribute suffix of the
    # matching ``norm_*`` / ``mlp_*`` sub-modules.
    _EXPERT_SUFFIX = {
        'spatial': 'spatial',
        'temporal': 'temp',
        'caption': 'cap',
        'history': 'hist',
    }

    def __init__(
        self,
        dim,
        num_heads,
        expert_type,
        use_sep_spatial_temp_experts=True,
        has_hist=False,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=LlamaRMSNorm,
    ):
        super().__init__()
        self.has_hist = has_hist
        self.use_sep_spatial_temp_experts = use_sep_spatial_temp_experts

        self.norm_att = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)

        # EXPERT CONSTRUCTION: collect the suffixes in creation order, then build.
        if expert_type == 'modalities':
            suffixes = []
            if use_sep_spatial_temp_experts:
                suffixes.extend(['spatial', 'temp'])
            suffixes.extend(['vis', 'cap'])
            if has_hist:
                suffixes.append('hist')
        elif expert_type == 'fusion':
            suffixes = ['fusion']
        else:
            raise ValueError

        for suffix in suffixes:
            setattr(self, 'norm_' + suffix, norm_layer(dim))
            setattr(
                self,
                'mlp_' + suffix,
                Mlp(
                    in_features=dim,
                    hidden_features=mlp_hidden_dim,
                    act_layer=act_layer,
                    drop=drop,
                ),
            )

    def _apply_expert(self, suffix, feats):
        """Return ``feats`` plus the drop-path'd output of expert ``suffix``."""
        norm = getattr(self, 'norm_' + suffix)
        mlp = getattr(self, 'mlp_' + suffix)
        return feats + self.drop_path(mlp(norm(feats)))

    def _route(self, modality, feats, expert_permutation):
        """Send ``feats`` to its own expert, or to the one named by the permutation.

        A permutation entry that names none of the known experts leaves
        ``feats`` unchanged.
        """
        if expert_permutation is None:
            return self._apply_expert(self._EXPERT_SUFFIX[modality], feats)
        suffix = self._EXPERT_SUFFIX.get(expert_permutation[modality])
        if suffix is None:
            return feats
        return self._apply_expert(suffix, feats)

    def forward(self, x, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len=None, is_vid=False, mask=None, only_text=False, expert_permutation=None):
        """Run attention, then the experts selected by ``expert_flag``.

        Args:
            x: Input sequence; sliced along dimension 1 as
                ``[visual | caption | history]``.
            vis_feat_len: Length of the spatial segment (and of the temporal
                segment that follows it when ``is_vid``).
            cap_feat_len: Length of the caption segment.
            expert_flag: ``'modalities'`` or ``'fusion'``; any other value
                returns the attention output unchanged.
            hist_feat_len: Length of the trailing history segment; required
                when the layer was built with ``has_hist``.
            is_vid: Whether a temporal segment follows the spatial one.
            mask: Forwarded to the attention module.
            only_text: Not read by this method.
            expert_permutation: Optional mapping with keys ``'spatial'``,
                ``'temporal'``, ``'caption'``, ``'history'`` naming the expert
                each segment is sent to instead of its own.
        """
        if self.has_hist:
            assert hist_feat_len is not None

        attn_out, _ = self.attn(self.norm_att(x), mask=mask)
        x = x + self.drop_path(attn_out)
        len_init = x.size(1)

        if expert_flag == 'fusion':
            return self._apply_expert('fusion', x)
        if expert_flag != 'modalities':
            return x

        # --- visual segment(s) ---
        if self.use_sep_spatial_temp_experts:
            x_spatial = self._route('spatial', x[:, :vis_feat_len], expert_permutation)
            x_vis = x_spatial
            if is_vid:
                x_temporal = self._route(
                    'temporal', x[:, vis_feat_len:2*vis_feat_len], expert_permutation
                )
                x_vis = torch.concat([x_spatial, x_temporal], dim=1)
                # NOTE(review): the vis expert is applied only on the video path here - confirm nesting.
                x_vis = self._apply_expert('vis', x_vis)
        else:
            x_vis = self._apply_expert('vis', x[:, :vis_feat_len])

        # --- text segment(s), taken from the end of the sequence ---
        if self.has_hist:
            x_caption = self._route(
                'caption', x[:, -(cap_feat_len + hist_feat_len): -hist_feat_len], expert_permutation
            )
            x_history = self._route('history', x[:, -hist_feat_len:], expert_permutation)
            segments = [x_vis, x_caption, x_history]
        else:
            # The permutation is not consulted on this path.
            x_caption = self._apply_expert('cap', x[:, -cap_feat_len:])
            segments = [x_vis, x_caption]

        # concat the features back
        x = torch.cat(segments, dim=1)
        assert x.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
            x.size(1), len_init
        )
        return x
class Pooler(nn.Module):
    """Pool a sequence by passing its first position through ``Linear`` + ``Tanh``."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return the pooled representation of index 0 along dimension 1."""
        first_position = hidden_states[:, 0]
        return self.activation(self.dense(first_position))

View file

@ -0,0 +1,234 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
import warnings
from torch import Tensor
from typing import Optional, Tuple
from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig
class MoELayer(nn.Module):
    """Shared BERT self-attention followed by per-modality BERT feed-forward experts.

    Each expert is a ``BertIntermediate`` / ``BertOutput`` pair. With
    ``expert_type == 'modalities'`` the layer owns spatial, temporal, vis and
    caption experts (plus a history expert unless ``config.stage`` is
    ``'stage_1'``); with ``expert_type == 'fusion'`` it owns one fusion expert.

    Raises:
        ValueError: If ``expert_type`` is neither ``'modalities'`` nor ``'fusion'``.
    """

    def __init__(self, config, expert_type):
        super(MoELayer, self).__init__()
        self.config = config
        self.expert_type = expert_type
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')

        # Shared across all experts
        self.attention = BertAttention(self.bert_config)

        # One for each expert
        if expert_type == 'modalities':
            # Spatial expert
            self.intermediate_spatial = BertIntermediate(self.bert_config)
            self.output_spatial = BertOutput(self.bert_config)
            # Temporal expert
            self.intermediate_temporal = BertIntermediate(self.bert_config)
            self.output_temporal = BertOutput(self.bert_config)
            # Vis Expert
            self.intermediate_vis = BertIntermediate(self.bert_config)
            self.output_vis = BertOutput(self.bert_config)
            # Caption Expert
            self.intermediate_caption = BertIntermediate(self.bert_config)
            self.output_caption = BertOutput(self.bert_config)
            if config.stage != 'stage_1':
                # History Expert
                self.intermediate_history = BertIntermediate(self.bert_config)
                self.output_history = BertOutput(self.bert_config)
        # Fusion expert
        elif expert_type == 'fusion':
            self.intermediate_fusion = BertIntermediate(self.bert_config)
            self.output_fusion = BertOutput(self.bert_config)
        else:
            raise ValueError(
                f"Unknown expert_type {expert_type!r}; expected 'modalities' or 'fusion'"
            )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear/Embedding weights (normal) and LayerNorm (identity)."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes a broadcastable additive attention mask so that masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
                Either `[batch_size, seq_length]` or `[batch_size, from_seq_length, to_seq_length]`.
            input_shape (`Tuple[int]`):
                The shape of the input to the model (only used in the error message).
            device (`torch.device`, *optional*):
                Deprecated; passing it only triggers a `FutureWarning`.
            dtype (`torch.dtype`, *optional*):
                dtype of the returned mask. Defaults to the dtype of this module's first
                parameter, or `torch.float32` if the module has no parameters.

        Returns:
            `torch.Tensor` The extended attention mask in `dtype`: 0.0 where attention is allowed and
            `torch.finfo(dtype).min` where it is masked.

        Raises:
            ValueError: If `attention_mask` is neither 2- nor 3-dimensional.
        """
        if dtype is None:
            # A plain nn.Module has no `dtype` attribute (that comes from HF's
            # ModuleUtilsMixin), so derive it from the parameters instead.
            first_param = next(self.parameters(), None)
            dtype = first_param.dtype if first_param is not None else torch.float32

        if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]: make it broadcastable to
            # [batch_size, num_heads, seq_length, seq_length]. No causal mask is added here.
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

    def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):
        """Run shared attention, then the experts selected by ``expert_flag``.

        Args:
            hidden_states: Input sequence; experts slice it along dimension 1.
            special_toks_indices: Mapping from special-token strings to positions.
                ``'<vis>'``, ``'<caption>'`` and ``'</s>'`` are required on the
                modality path; ``'<temporal>'`` and ``'<history>'`` are optional.
            expert_flag: ``'modalities'`` or ``'fusion'``.
            mask: Attention mask (1 = attend, 0 = ignore), 2-D or 3-D. ``None``
                means every position is attended to.
            only_text: On the modality path, send the whole sequence through
                the caption expert.
            output_attentions: Forwarded to the attention module.

        Returns:
            The layer output, with the same sequence length as the input.

        Raises:
            ValueError: If ``expert_flag`` is not a known value.
        """
        if expert_flag not in ('modalities', 'fusion'):
            # Previously this fell through to an UnboundLocalError at the return.
            raise ValueError(
                f"Unknown expert_flag {expert_flag!r}; expected 'modalities' or 'fusion'"
            )

        input_shape = hidden_states.size()[:-1]
        if mask is None:
            # The signature advertises mask=None: treat it as "attend everywhere".
            mask = torch.ones(input_shape, device=hidden_states.device)

        extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            head_mask=None
        )
        attention_output = self_attention_outputs[0]
        len_init = attention_output.size(1)

        if expert_flag == 'fusion':
            intermediate_output = self.intermediate_fusion(attention_output)
            return self.output_fusion(intermediate_output, attention_output)

        if only_text:
            intermediate_output = self.intermediate_caption(attention_output)
            layer_output = self.output_caption(intermediate_output, attention_output)
        else:
            # split the input first into different parts/modalities
            # Positions before '<vis>' bypass the experts.
            unchanged = attention_output[:, :special_toks_indices['<vis>'], :]
            end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
            attention_spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]

            end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
            attention_caption = attention_output[:, special_toks_indices['<caption>']: end_idx_caption, :]

            attention_temporal, attention_history = None, None
            if '<temporal>' in special_toks_indices:
                end_idx_temporal = special_toks_indices['<caption>']
                attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]
            if '<history>' in special_toks_indices:
                end_idx_history = special_toks_indices['</s>'] + 1
                attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]

            # Expert activation
            # 1- Spatial
            intermediate_spatial = self.intermediate_spatial(attention_spatial)
            spatial_out = self.output_spatial(intermediate_spatial, attention_spatial)
            output_vis = spatial_out

            # 2- Temporal (the vis expert only runs on the joint spatial + temporal features)
            if attention_temporal is not None:
                intermediate_temporal = self.intermediate_temporal(attention_temporal)
                output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)
                attention_vis = torch.concat([spatial_out, output_temporal], dim=1)
                intermediate_vis = self.intermediate_vis(attention_vis)
                output_vis = self.output_vis(intermediate_vis, attention_vis)

            # 3- Caption
            intermediate_caption = self.intermediate_caption(attention_caption)
            output_caption = self.output_caption(intermediate_caption, attention_caption)

            output_list = [unchanged, output_vis, output_caption]

            # 4- History
            if attention_history is not None:
                intermediate_history = self.intermediate_history(attention_history)
                output_history = self.output_history(intermediate_history, attention_history)
                output_list.append(output_history)

            # Concat the features back
            layer_output = torch.concat(output_list, dim=1)

        assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
            layer_output.size(1), len_init
        )
        return layer_output
class MoEPooler(nn.Module):
    """Pool one position of a sequence through ``Linear`` + ``Tanh``.

    The layer width is ``hidden_size`` of the ``bert-large-uncased`` config.
    """

    def __init__(self):
        super().__init__()
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
        width = self.bert_config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()
        self._init_weights()

    def _init_weights(self):
        """Normal-initialise Linear/Embedding weights, reset LayerNorm, zero Linear biases."""
        for module in self.modules():
            is_linear = isinstance(module, nn.Linear)
            if is_linear or isinstance(module, nn.Embedding):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, hidden_states, idx):
        """Return the pooled representation of position ``idx`` along dimension 1."""
        selected = hidden_states[:, idx]
        return self.activation(self.dense(selected))

View file

@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
class Mlp(nn.Module):
    """Feed-forward block: linear, activation, dropout, linear, dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    they are falsy. One dropout module (probability ``drop``) is reused after
    both linear layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Transform ``x`` along its last dimension."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and optional mask.

    Args:
        dim: Channel width of the input and output.
        num_heads: Number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_scale: Score scale; a falsy value means ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, N, C)``.

        Args:
            x: Input sequence.
            mask: Optional mask, converted to bool; False positions are set to
                ``-inf`` before the softmax. Accepted layouts:

                * ``(B, N)`` - key padding mask, shared by all heads and queries;
                * ``(B, N, N)`` - per-sample mask, shared by all heads;
                * 4-D - used as given, must broadcast to ``(B, heads, N, N)``.

        Returns:
            ``(output, attn)``: the projected output with the shape of ``x`` and
            the attention weights (after dropout) of shape ``(B, heads, N, N)``.

        Raises:
            ValueError: If ``mask`` has an unsupported number of dimensions.
        """
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
            elif mask.dim() == 3:
                # Insert the head axis explicitly. Used as-is, a (B, N, N) mask would be
                # right-aligned against (heads, N, N) and pair batch entries with heads.
                expanded_mask = mask[:, None, :, :]
            elif mask.dim() == 4:
                expanded_mask = mask
            else:
                raise ValueError(
                    'Attention mask must be 2-, 3- or 4-dimensional, got shape {}'.format(tuple(mask.shape))
                )
            expanded_mask = expanded_mask.bool()
            attn = attn.masked_fill(~expanded_mask, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
class MoELayer(nn.Module):
    """Pre-norm transformer block with shared attention and six MLP experts.

    The experts (spatial, temporal, vis, caption, history, fusion) are stored
    as ``norm_<suffix>`` / ``mlp_<suffix>`` pairs and applied as residual
    updates to spans of the sequence delimited by special-token positions.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.SiLU,
        norm_layer=LlamaRMSNorm,
    ):
        super().__init__()
        self.norm_att = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # EXPERT CONSTRUCTION: spatial, temporal, vis, caption, history, fusion (in this order).
        mlp_hidden_dim = int(dim * mlp_ratio)
        for suffix in ('spatial', 'temp', 'vis', 'cap', 'hist', 'fusion'):
            setattr(self, 'norm_' + suffix, norm_layer(dim))
            setattr(
                self,
                'mlp_' + suffix,
                Mlp(
                    in_features=dim,
                    hidden_features=mlp_hidden_dim,
                    act_layer=act_layer,
                    drop=drop,
                ),
            )

    def _update_span(self, x, start, end, suffix):
        """Apply expert ``suffix`` to ``x[:, start:end]`` and scatter the result back."""
        norm = getattr(self, 'norm_' + suffix)
        mlp = getattr(self, 'mlp_' + suffix)
        feats = x[:, start:end, :]
        feats = feats + self.drop_path(mlp(norm(feats)))
        # Out-of-place equivalent of `x[:, start:end, :] = feats`.
        positions = torch.arange(start, end, device=x.device)
        positions = positions[None, :, None].repeat(x.size(0), 1, x.size(-1))
        return x.scatter(1, positions, feats)

    def forward(self, x, special_toks_indices, expert_flag, mask=None):
        """Run attention, then the experts selected by ``expert_flag``.

        Args:
            x: Input sequence; experts operate on spans along dimension 1.
            special_toks_indices: Mapping from special-token strings to
                positions. ``'<spatial>'``, ``'<caption>'`` and ``'</s>'`` are
                read on the modality path; ``'<temporal>'`` (together with
                ``'<vis>'``) and ``'<history>'`` are optional.
            expert_flag: ``'modalities'`` or ``'fusion'``; any other value
                returns the attention output unchanged.
            mask: Forwarded to the attention module.

        Returns:
            ``(x, attn)``: the updated sequence and the attention weights.
        """
        attn_out, attn = self.attn(self.norm_att(x), mask=mask)
        x = x + self.drop_path(attn_out)

        if expert_flag == 'modalities':
            toks = special_toks_indices

            spatial_end = toks.get('<temporal>', toks['<caption>'])
            x = self._update_span(x, toks['<spatial>'], spatial_end, 'spatial')

            caption_end = toks.get('<history>', toks['</s>'])
            x = self._update_span(x, toks['<caption>'], caption_end, 'cap')

            if '<temporal>' in toks:
                x = self._update_span(x, toks['<temporal>'], toks['<caption>'], 'temp')
                # NOTE(review): the vis expert is applied only when a temporal span exists - confirm nesting.
                x = self._update_span(x, toks['<vis>'], toks['<caption>'], 'vis')

            if '<history>' in toks:
                x = self._update_span(x, toks['<history>'], toks['</s>'], 'hist')
        elif expert_flag == 'fusion':
            x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))

        return x, attn
# if expert_flag == 2:
# x = x + self.drop_path(self.mlp(self.norm2(x)))
# elif expert_flag == 0:
# x = (x[:, -it_split:])
# x = x + self.drop_path(self.sentence_mlp(self.sentence_norm(x)))
# elif expert_flag == 1:
# x = (x[:, :-it_split ])
# x = x + self.drop_path(self.image_mlp(self.image_norm(x)))
# elif expert_flag == 3:
# text, image = (x[:, :it_split], x[:, it_split:],)
# text = text + self.drop_path(self.sentence_mlp(self.sentence_norm(text)))
# image = image + self.drop_path(self.image_mlp(self.image_norm(image)))
# x = torch.cat([text, image], dim=1)
# elif expert_flag == 4:
# x = x + self.drop_path(self.generation_mlp(self.generation_norm(x)))
# return x, attn