initial commit
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions
1216
models/backbones/Qformer.py
Executable file
File diff suppressed because it is too large
0
models/backbones/__init__.py
Normal file
247
models/backbones/base_model.py
Executable file
@@ -0,0 +1,247 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os

import numpy as np
import torch
import torch.nn as nn
from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from models.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf


class BaseModel(nn.Module):
    """Base class for models."""

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        return list(self.parameters())[0].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch in the model keys and the checkpoint keys.
        """

        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define their
        own load_from_pretrained() method.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            # load pre-trained weights
            pretrain_path = cfg.get("pretrained", None)
            assert (
                pretrain_path is not None
            ), "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_evaluation(self, **kwargs):
        pass

    def show_n_params(self, return_str=True):
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot


class BaseEncoder(nn.Module):
    """
    Base class for primitive encoders, such as ViT, TimeSformer, etc.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        raise NotImplementedError

    @property
    def device(self):
        return list(self.parameters())[0].device


class SharedQueueMixin:
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = concat_all_gather(idxs)
            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr


class MomentumDistilationMixin:
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # if use distributed training
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))
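Usage note: a minimal sketch of how a concrete model is expected to plug into the helpers above. The subclass name, config path, and layer sizes are hypothetical; `PRETRAINED_MODEL_CONFIG_DICT` and `from_config` must be supplied by the subclass, as the base class only defines the loading contract.

# Hypothetical subclass illustrating the BaseModel contract (sketch only).
import torch.nn as nn
from models.backbones.base_model import BaseModel

class ToyModel(BaseModel):
    # maps a model_type string to a config path resolved by get_abs_path()
    PRETRAINED_MODEL_CONFIG_DICT = {"toy": "configs/models/toy.yaml"}  # hypothetical path

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    @classmethod
    def from_config(cls, cfg):
        model = cls()
        # reads load_finetuned / finetuned / pretrained keys from cfg
        model.load_checkpoint_from_config(cfg)
        return model

# model = ToyModel.from_pretrained("toy")  # would read configs/models/toy.yaml
# print(ToyModel().show_n_params())        # "0.1K" for the 8x8 linear layer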
0
models/backbones/beit/__init__.py
Normal file
107
models/backbones/beit/builder.py
Normal file
@@ -0,0 +1,107 @@
import logging
import torch
from models.utils import (interpolate_pos_relative_bias_beit,
                          load_temp_embed_with_mismatch)

logger = logging.getLogger(__name__)


def interpolate_pos_embed_beit(state_dict, new_model):
    """interpolate the positional embeddings.
    The spatial pe is relative and temporal pe is absolute.
    additional temporal pe is padded with 0.

    Args:
        state_dict (dict): The state_dict.
        new_model (nn.Module): The created model.

    Returns: dict. The state_dict with updated positional embeddings.

    """
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=new_model.state_dict(),
        patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape,
    )
    # absolute temporal pos bias
    temporal_pe_key = "beit.embeddings.temporal_position_embeddings"
    if temporal_pe_key in state_dict:
        logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
        state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
            temp_embed_old=state_dict[temporal_pe_key],
            temp_embed_new=new_model.state_dict()[temporal_pe_key],
        )
    return state_dict


def extract_beit_from_vindlu(vindlu_state_dict):
    beit_state_dict = {}
    beit_param_names = [k for k in vindlu_state_dict if k.startswith('vision_encoder.') and 'temp_model' not in k]
    for param_name in beit_param_names:
        new_name = param_name.replace('vision_encoder.', '')
        beit_state_dict[new_name] = vindlu_state_dict[param_name]

    return beit_state_dict


def build_beit(model_config, image_res, checkpoint=False):
    """build beit with configuration.

    Args:
        config (dict): The configs for beit.
        image_res (int): The image resolution.
        checkpoint (bool): Whether to enable gradient checkpointing.

    Returns: nn.Module

    """
    from .st_beit import BeitConfig as config_cls
    from .st_beit import BeitModel as model_cls

    vindlu_state_dict = torch.load(model_config['vindlu_path'])['model']
    state_dict = extract_beit_from_vindlu(vindlu_state_dict)
    model_config = model_config['beit_config_json']

    logger.info(
        f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}."
    )
    # BEiT uses average pooled tokens instead of [CLS] used by other models
    aux_kwargs = {"add_pooling_layer": True}
    # tmp_model = model_cls.from_pretrained(model_config['beit_pretrained'], **aux_kwargs)

    # tmp_model = model_cls.from_pretrained(model_config['pretrained'], **aux_kwargs)
    # state_dict = tmp_model.state_dict()

    # del tmp_model

    logger.info(f"Init new model with new image size {image_res}, and load weights.")

    # other_cfg = model_config.temporal_modeling
    other_cfg = {}

    vit_config = config_cls.from_pretrained(
        model_config['pretrained'], image_size=image_res, **other_cfg
    )

    # vit_config.update(model_config)

    model = model_cls(config=vit_config, **aux_kwargs)

    if checkpoint:
        model.gradient_checkpointing_enable()

    # interpolate relative pos bias
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=model.state_dict(),
        patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
    )

    # del prompt_bias_table
    for k in list(state_dict.keys()):
        if "prompt_bias_table" in k:
            del state_dict[k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    return model
1752
models/backbones/beit/st_beit.py
Normal file
File diff suppressed because it is too large
0
models/backbones/bert/__init__.py
Normal file
71
models/backbones/bert/builder.py
Normal file
@@ -0,0 +1,71 @@
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel


def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
    """build text encoder.

    Args:
        model_config (dict): model config.
        pretrain (bool): Whether to do pretrain or finetuning.
        checkpoint (bool): whether to do gradient_checkpointing.
        expert_type (str): which expert this encoder serves; selects the number of
            hidden layers and the cross-attention frequency.
        modality_type (str): 'vis' or 'text'; only used when expert_type == 'modality'.

    Returns:
        nn.Module: the BERT-based text encoder.

    """
    bert_size = model_config['expert_size']
    bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
    # bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint
    bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]
    if expert_type == 'modality':
        if modality_type == 'vis':
            bert_config.cross_attention_freq = 2
        else:
            bert_config.cross_attention_freq = -1
    else:
        bert_config.cross_attention_freq = 1

    if pretrain:
        text_encoder, loading_info = BertForMaskedLM.from_pretrained(
            f'bert-{bert_size}-uncased',
            config=bert_config,
            output_loading_info=True,
        )
    else:
        text_encoder, loading_info = BertModel.from_pretrained(
            f'bert-{bert_size}-uncased',
            config=bert_config,
            add_pooling_layer=True,
            output_loading_info=True,
        )

    return text_encoder


def build_bert_decoder(model_config, checkpoint):
    """build text decoder the same as the multimodal encoder.

    Args:
        model_config (dict): model config.
        checkpoint (bool): whether to do gradient_checkpointing.

    Returns:
        nn.Module: the BERT-based text decoder.

    """
    bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
    bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint

    bert_config.fusion_layer = 0
    bert_config.num_hidden_layers = (
        bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
    )

    text_decoder, loading_info = BertLMHeadModel.from_pretrained(
        model_config.text_encoder.pretrained,
        config=bert_config,
        output_loading_info=True,
    )

    return text_decoder
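Usage note: a hedged call sketch for `build_bert`. The keys match what the function reads above; the concrete values and paths are hypothetical.

# Hypothetical model_config for build_bert (keys mirror the lookups above).
model_config = {
    'expert_size': 'base',                         # -> 'bert-base-uncased'
    'bert_config_base': 'configs/bert_base.json',  # hypothetical path to a BertConfig JSON
    'num_layers_modality_expert': 2,               # read for expert_type='modality'
}
# text_encoder = build_bert(model_config, pretrain=True, checkpoint=False,
#                           expert_type='modality', modality_type='vis')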
546
models/backbones/bert/tokenization_bert.py
Normal file
@@ -0,0 +1,546 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""


import collections
import os
import unicodedata
from typing import List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
        "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
        "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
        "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
        "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
        "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
        "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
        "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
        "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-uncased": 512,
    "bert-large-uncased": 512,
    "bert-base-cased": 512,
    "bert-large-cased": 512,
    "bert-base-multilingual-uncased": 512,
    "bert-base-multilingual-cased": 512,
    "bert-base-chinese": 512,
    "bert-base-german-cased": 512,
    "bert-large-uncased-whole-word-masking": 512,
    "bert-large-cased-whole-word-masking": 512,
    "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
    "bert-large-cased-whole-word-masking-finetuned-squad": 512,
    "bert-base-cased-finetuned-mrpc": 512,
    "bert-base-german-dbmdz-cased": 512,
    "bert-base-german-dbmdz-uncased": 512,
    "TurkuNLP/bert-base-finnish-cased-v1": 512,
    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
    "wietsedv/bert-base-dutch-cased": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-uncased": {"do_lower_case": True},
    "bert-large-uncased": {"do_lower_case": True},
    "bert-base-cased": {"do_lower_case": False},
    "bert-large-cased": {"do_lower_case": False},
    "bert-base-multilingual-uncased": {"do_lower_case": True},
    "bert-base-multilingual-cased": {"do_lower_case": False},
    "bert-base-chinese": {"do_lower_case": False},
    "bert-base-german-cased": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
    "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):

                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: ``[CLS] X``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") +
                VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix +
                          "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.
        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(
            set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.
        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.
        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
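Usage note: `build_inputs_with_special_tokens` above deviates from stock BERT for a single sequence (no trailing `[SEP]`). A small sketch, assuming a local BERT `vocab.txt` and the transformers version this repo targets:

# Sketch; the vocab.txt path is hypothetical.
tokenizer = BertTokenizer(vocab_file="vocab.txt")
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("unaffable"))
single = tokenizer.build_inputs_with_special_tokens(ids)       # [CLS] X   (no trailing [SEP], see above)
pair = tokenizer.build_inputs_with_special_tokens(ids, ids)    # [CLS] A [SEP] B [SEP]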
2160
models/backbones/bert/xbert.py
Normal file
File diff suppressed because it is too large
268
models/backbones/blip2.py
Executable file
@@ -0,0 +1,268 @@
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

# import .backbones.common.dist_utils as dist_utils
# from minigpt4.common.dist_utils import download_cached_file
# from minigpt4.common.utils import is_url
# from minigpt4.common.logger import MetricLogger
# NOTE: is_url, download_cached_file, dist_utils and MetricLogger are used below but were
# only imported via the commented minigpt4 lines above. The paths here are an assumption,
# mirroring base_model.py, which imports the same helpers from models.common.
import models.common.dist_utils as dist_utils
from models.common.dist_utils import download_cached_file
from models.common.utils import is_url
from models.common.logger import MetricLogger

from models.backbones.base_model import BaseModel
from models.backbones.Qformer import BertConfig, BertLMHeadModel
from models.backbones.eva_vit import create_eva_vit_g
from transformers import BertTokenizer


class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    def get_optimizer_params(self, weight_decay, lr_scale=1):

        vit_num_layers = self.visual_encoder.get_num_layer()
        lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))

        parameter_group_names = {}
        parameter_group_vars = {}

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias"):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay
            if 'visual_encoder' in name:
                layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
            else:
                layer_id = None

            if group_name not in parameter_group_names:
                if layer_id is not None:
                    scale = lr_scales[layer_id]
                else:
                    scale = 1
                parameter_group_names[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale
                }
                parameter_group_vars[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale
                }
            parameter_group_vars[group_name]["params"].append(param)
            parameter_group_names[group_name]["params"].append(name)
        # import json
        # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
        optim_params = list(parameter_group_vars.values())
        return optim_params


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def compute_sim_matrix(model, data_loader, **kwargs):
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        text_embed = F.normalize(model.text_proj(text_feat))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = model.vision_proj(image_feat)
        image_embed = F.normalize(image_embed, dim=-1)

        vit_feats.append(vit_feat.cpu())
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = []
    for image_embed in image_embeds:
        sim_q2t = image_embed @ text_embeds.t()
        sim_i2t, _ = sim_q2t.max(0)
        sims_matrix.append(sim_i2t)
    sims_matrix = torch.stack(sims_matrix, dim=0)

    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k_test, 1),
            text_atts=text_atts[start + i].repeat(k_test, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(
            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
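Usage note: a brief sketch of how the Q-Former initializer above is typically used; the dimensions are hypothetical.

# Hypothetical dimensions; init_Qformer wires cross-attention to the vision width.
num_query_token, vision_width = 32, 1408
Qformer, query_tokens = Blip2Base.init_Qformer(num_query_token, vision_width)
# query_tokens: nn.Parameter of shape (1, 32, hidden_size), expanded per batch at forward time.
# Inside a subclass, a mixed-precision vision forward would be wrapped as:
#   with self.maybe_autocast():
#       image_embeds = self.ln_vision(self.visual_encoder(image))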
110
models/backbones/blip2_outputs.py
Executable file
@@ -0,0 +1,110 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)


@dataclass
class BlipSimilarity(ModelOutput):
    sim_i2t: torch.FloatTensor = None
    sim_t2i: torch.FloatTensor = None

    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None


@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Data class for intermediate outputs of BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)

    """

    # uni-modal features
    image_embeds: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None


@dataclass
class BlipOutput(ModelOutput):
    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: BlipIntermediateOutput = None

    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None


@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Data class of features from BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional

        The first embedding or feature is for the [CLS] token.

        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None
83
models/backbones/clip_vision_encoder.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionEncoder(nn.Module):
    def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = encoder_name
        # self.select_layer = args.mm_vision_select_layer
        # self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.select_layer = -1
        self.select_feature = "patch"
        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
        self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
        self.vision_encoder.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, :]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)
            # print("image feature shape", image_features.shape)
            # print(type(image_forward_outs))
            # print(type(image_forward_outs.shape))
            # image_features = image_forward_outs.to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_encoder.dtype

    @property
    def device(self):
        return self.vision_encoder.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_encoder.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
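A minimal usage sketch for the class above, assuming the transformers checkpoint can be downloaded and PIL/numpy are available; the dummy image is purely illustrative:

from PIL import Image
import numpy as np
import torch

from models.backbones.clip_vision_encoder import CLIPVisionEncoder

encoder = CLIPVisionEncoder("openai/clip-vit-large-patch14")

# Preprocess a (dummy) image with the bundled CLIPImageProcessor.
dummy = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
pixel_values = encoder.image_processor(images=dummy, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    feats = encoder(pixel_values)

# With select_layer=-1 and select_feature="patch" (a no-op slice here), this returns the last
# hidden state, i.e. the CLS token plus one token per patch: (1, 257, 1024) for ViT-L/14 at 224px.
print(feats.shape)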
141
models/backbones/encoder_decoder/builder.py
Normal file
@ -0,0 +1,141 @@

import glog as logger
import re
import json

from peft import LoraConfig, get_peft_model

from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart import BartConfig, BartForConditionalGeneration, BartEncoder, BartForCausalLM


def build_encoder_decoder(model_config):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Encoder Decoder [Type = {}]'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        if model_config['use_decoder_only']:
            model_cls = BartForCausalLM
        else:
            model_cls = BartForConditionalGeneration
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
    enc_dec_config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = enc_dec_config.d_model
    # enc_dec_config.encoder_layers = enc_dec_config.encoder_layers - model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    enc_dec = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=enc_dec_config
    )

    # first_k = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    # enc_dec.model.encoder.remove_first_k_layers(first_k)
    # get the last encoder layers
    # enc_dec.

    if model_config['use_lora_enc_dec']:
        # load the lora config
        with open(model_config['lora_config'], 'r') as f:
            lora_config = json.load(f)

        # get the linear layer to perform LoRA on
        model_modules = str(enc_dec.modules)
        pattern = r'\((\w+)\): Linear'
        linear_layer_names = re.findall(pattern, model_modules)

        names = []
        # Print the names of the Linear layers
        for name in linear_layer_names:
            names.append(name)
        target_modules = list(set(names))

        lora_config['target_modules'] = target_modules

        lora_config = LoraConfig(**lora_config)

        enc_dec = get_peft_model(enc_dec, lora_config)

    return enc_dec


def build_encoder(model_config, expert_type, modality=None):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    log_txt = '[INFO] Loading {} Expert'.format(expert_type)
    if modality is not None:
        log_txt += ' [Modality = {}]'.format(modality)
    log_txt += ' [Type = {}]'.format(model_config['enc_dec_name'])

    logger.info(log_txt)

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartEncoder
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    config.modality_expert_layers = model_config['num_layers_modality_expert_{}'.format(model_config['enc_dec_family'])]
    config.grounding_expert_layers = model_config['num_layers_grounding_expert_{}'.format(model_config['enc_dec_family'])]

    model_config['enc_dec_dim'] = config.d_model

    expert = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config,
        expert_type=expert_type,
        modality=modality
    )

    if model_config['use_lora_expert']:
        # load the lora config
        with open(model_config['lora_config'], 'r') as f:
            lora_config = json.load(f)

        # get the linear layer to perform LoRA on
        model_modules = str(expert.modules)
        pattern = r'\((\w+)\): Linear'
        linear_layer_names = re.findall(pattern, model_modules)

        names = []
        # Print the names of the Linear layers
        for name in linear_layer_names:
            names.append(name)
        target_modules = list(set(names))

        lora_config['target_modules'] = target_modules

        lora_config = LoraConfig(**lora_config)

        expert = get_peft_model(expert, lora_config)

    # expert = model_cls(
    #     config=config,
    #     expert_type=expert_type,
    #     modality=modality
    # )

    return expert
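Both builder functions discover LoRA target modules by regex-matching the module-tree repr: `str(model.modules)` embeds the printed representation of every submodule, and `r'\((\w+)\): Linear'` captures the attribute names of the nn.Linear layers. A self-contained toy demonstration of that trick (the model here is illustrative only):

import re
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.k_proj = nn.Linear(16, 16)
        self.out_proj = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)

model = nn.Sequential(TinyBlock(), TinyBlock())

# str(model.modules) contains lines such as "(q_proj): Linear(in_features=16, ...)".
pattern = r'\((\w+)\): Linear'
linear_layer_names = re.findall(pattern, str(model.modules))

# Deduplicate, exactly as builder.py does before passing the list to LoraConfig(target_modules=...).
print(sorted(set(linear_layer_names)))  # ['k_proj', 'out_proj', 'q_proj']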
65
models/backbones/encoder_decoder/builder_orig.py
Normal file
@ -0,0 +1,65 @@
from .xflan_t5 import T5Config, T5ForConditionalGeneration
from .xbart_original import BartConfig, BartForConditionalGeneration, BartEncoder

import glog as logger


def build_encoder_decoder(model_config):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Encoder Decoder: {}'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartForConditionalGeneration
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))
    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = config.d_model
    enc_dec = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config
    )

    return enc_dec


def build_encoder(model_config):
    """build (encoder-) decoder model for answer generation.

    Args:
        model_config (dict): model config.

    Returns: TODO

    """
    logger.info('[INFO] Loading Expert as Encoder of {}'.format(model_config['enc_dec_name']))

    if model_config['enc_dec_family'] == 'flan_t5':
        config_cls = T5Config
        model_cls = T5ForConditionalGeneration
    elif model_config['enc_dec_family'] == 'bart':
        config_cls = BartConfig
        model_cls = BartEncoder
    else:
        raise ValueError('{} is not supported'.format(model_config['enc_dec_family']))

    config = config_cls.from_pretrained(model_config['enc_dec_name'])
    model_config['enc_dec_dim'] = config.d_model
    config.encoder_layers = model_config['num_layers_modality_expert']

    expert = model_cls.from_pretrained(
        model_config['enc_dec_name'],
        config=config
    )

    return expert
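The builders above take a plain dict and only read a handful of keys. A hypothetical minimal config for this original variant; the key names follow the lookups in the code, while the values are illustrative assumptions, not taken from the repo's configs:

from models.backbones.encoder_decoder.builder_orig import build_encoder_decoder

model_config = {
    'enc_dec_family': 'flan_t5',            # or 'bart'
    'enc_dec_name': 'google/flan-t5-base',  # any checkpoint of the chosen family (assumed)
    'num_layers_modality_expert': 4,        # only used by build_encoder
}

enc_dec = build_encoder_decoder(model_config)
# Side effect: model_config['enc_dec_dim'] now holds the model's d_model.
print(model_config['enc_dec_dim'])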
19
models/backbones/encoder_decoder/outputs.py
Normal file
@ -0,0 +1,19 @@
from typing import Optional, Tuple
import torch
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass


@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None
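Like other transformers ModelOutput subclasses, this dataclass supports attribute access, dict-style access, and tuple conversion that skips None fields. A small illustrative sketch (the tensors and vocabulary size are made up):

import torch

from models.backbones.encoder_decoder.outputs import Seq2SeqV2DialOutput

out = Seq2SeqV2DialOutput(
    loss=torch.tensor(1.25),
    logits=torch.randn(2, 7, 32128),  # (batch, seq_len, vocab) with an assumed vocab size
)

print(out.loss)             # attribute access
print(out["logits"].shape)  # dict-style access
print(out.to_tuple())       # only the fields that are not None, in declaration order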
2044
models/backbones/encoder_decoder/xbart.py
Normal file
File diff suppressed because it is too large
1954
models/backbones/encoder_decoder/xbart_original.py
Normal file
File diff suppressed because it is too large
2075
models/backbones/encoder_decoder/xflan_t5.py
Normal file
File diff suppressed because it is too large
455
models/backbones/eva_vit.py
Executable file
@ -0,0 +1,455 @@
# Based on EVA, BEIT, timm and DeiT code bases
|
||||
# https://github.com/baaivision/EVA
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
from timm.models.registry import register_model
|
||||
|
||||
from models.common.dist_utils import download_cached_file
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'p={}'.format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commented out to match the original BERT implementation
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
||||
proj_drop=0., window_size=None, attn_head_dim=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||
window_size=None, attn_head_dim=None):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
if init_values is not None and init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
||||
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.image_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
||||
for i in range(depth)])
|
||||
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# self.head.weight.data.mul_(init_scale)
|
||||
# self.head.bias.data.mul_(init_scale)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
||||
else:
|
||||
x = blk(x, rel_pos_bias)
|
||||
return x
|
||||
# x = self.norm(x)
|
||||
|
||||
# if self.fc_norm is not None:
|
||||
# t = x[:, 1:, :]
|
||||
# return self.fc_norm(t.mean(1))
|
||||
# else:
|
||||
# return x[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
# x = self.head(x)
|
||||
return x
|
||||
|
||||
def get_intermediate_layers(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
features = []
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
for blk in self.blocks:
|
||||
x = blk(x, rel_pos_bias)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
def get_num_layer(self, var_name=""):
|
||||
if var_name in ("cls_token", "mask_token", "pos_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("patch_embed"):
|
||||
return 0
|
||||
elif var_name.startswith("rel_pos_bias"):
|
||||
return len(self.blocks) - 1
|
||||
elif var_name.startswith("blocks"):
|
||||
layer_id = int(var_name.split('.')[1])
|
||||
return layer_id + 1
|
||||
else:
|
||||
return len(self.blocks)
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if 'pos_embed' in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
||||
|
||||
|
||||
def convert_weights_to_fp16(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
# tensor = getattr(l, attr)
|
||||
# if tensor is not None:
|
||||
# tensor.data = tensor.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
|
||||
model = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=14,
|
||||
use_mean_pooling=False,
|
||||
embed_dim=1408,
|
||||
depth=39,
|
||||
# depth = 37,
|
||||
num_heads=1408//88,
|
||||
mlp_ratio=4.3637,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
||||
cached_file = download_cached_file(
|
||||
url, check_hash=False, progress=True
|
||||
)
|
||||
state_dict = torch.load(cached_file, map_location="cpu")
|
||||
interpolate_pos_embed(model,state_dict)
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||
# print(incompatible_keys)
|
||||
|
||||
if precision == "fp16":
|
||||
# model.to("cuda")
|
||||
convert_weights_to_fp16(model)
|
||||
return model
|
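interpolate_pos_embed above resizes only the patch-grid part of a checkpoint's absolute position embedding with bicubic interpolation, leaving the extra (class) tokens untouched. A self-contained toy run of the same reshape-interpolate-flatten pipeline, with sizes assumed purely for illustration:

import torch
import torch.nn.functional as F

# Assumed toy sizes: checkpoint trained on a 16x16 patch grid, target model uses 24x24,
# embedding size 8, one extra (class) token.
embedding_size, num_extra_tokens = 8, 1
orig_size, new_size = 16, 24

pos_embed_checkpoint = torch.randn(1, num_extra_tokens + orig_size * orig_size, embedding_size)

extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]

# Same steps as interpolate_pos_embed: grid layout -> bicubic resize -> back to a token sequence.
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = F.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)

new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print(new_pos_embed.shape)  # torch.Size([1, 577, 8]) -> 1 class token + 24*24 patch tokens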
895
models/backbones/mini_gpt4_llama_v2.py
Executable file
@ -0,0 +1,895 @@
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast as autocast
|
||||
import torch.nn as nn
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.blip2 import Blip2Base, disabled_train
|
||||
# from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
|
||||
# minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
|
||||
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
|
||||
|
||||
from transformers import LlamaTokenizer
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training,
|
||||
set_peft_model_state_dict,
|
||||
)
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from minigpt4.models import policies
|
||||
|
||||
|
||||
@registry.register_model("mini_gpt4_llama_v2")
|
||||
class MiniGPT4_llama_v2(Blip2Base):
|
||||
"""
|
||||
BLIP2 GPT-LLAMA model.
|
||||
"""
|
||||
|
||||
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||
"pretrain_vicuna": "configs/models/minigpt4.yaml",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model="eva_clip_g",
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision="fp16",
|
||||
freeze_vit=True,
|
||||
llama_model="",
|
||||
prompt_path="",
|
||||
prompt_template="",
|
||||
max_txt_len=32,
|
||||
low_resource=False, # use 8 bit and put vit in cpu
|
||||
end_sym='\n',
|
||||
lora_r = 8,
|
||||
lora_target_modules = ["q_proj","v_proj"],
|
||||
lora_alpha=16,
|
||||
# lora_r = 16,
|
||||
# lora_target_modules = ["q_proj","v_proj","v_proj"],
|
||||
lora_dropout= 0.05,
|
||||
ckpt_path = "",
|
||||
system_prompt= False,
|
||||
chat_template=False,
|
||||
token_pooling=True,
|
||||
use_grad_checkpoint_llm=False,
|
||||
max_context_len=3800,
|
||||
remove_template = False,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
if "Mistral" in llama_model:
|
||||
from minigpt4.models.modeling_mistral import MistralForCausalLM as llm_model
|
||||
print("Mistral model")
|
||||
self.model_type = "Mistral"
|
||||
else:
|
||||
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM as llm_model
|
||||
print("Llama model")
|
||||
self.model_type = "Llama"
|
||||
self.tokenizer = self.init_tokenizer()
|
||||
self.low_resource = low_resource
|
||||
self.token_pooling = token_pooling
|
||||
self.remove_template = remove_template
|
||||
|
||||
print("token pooling", self.token_pooling)
|
||||
|
||||
|
||||
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
|
||||
self.max_context_len = max_context_len
|
||||
self.chat_template = chat_template
|
||||
|
||||
# print('Loading VIT')
|
||||
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
# )
|
||||
|
||||
if freeze_vit:
|
||||
# vit_precision="fp32"
|
||||
print("vit precision", vit_precision)
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
)
|
||||
for name, param in self.visual_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.visual_encoder = self.visual_encoder.eval()
|
||||
self.visual_encoder.train = disabled_train
|
||||
for name, param in self.ln_vision.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.ln_vision = self.ln_vision.eval()
|
||||
self.ln_vision.train = disabled_train
|
||||
logging.info("freeze vision encoder")
|
||||
print("freeze the vision encoder")
|
||||
|
||||
else:
|
||||
vit_precision="fp32"
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
)
|
||||
|
||||
print("unfreeze the vision encoder")
|
||||
|
||||
print('Loading VIT Done')
|
||||
|
||||
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
|
||||
# assert False
|
||||
|
||||
print('Loading LLAMA')
|
||||
|
||||
|
||||
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model,use_fast=False) #
|
||||
self.llama_tokenizer.pad_token = "$$"
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
|
||||
|
||||
print("self.low_resource",self.low_resource)
|
||||
if self.low_resource:
|
||||
self.llama_model = llm_model.from_pretrained(
|
||||
llama_model,
|
||||
torch_dtype=torch.float16,
|
||||
# torch_dtype = torch.bfloat16,
|
||||
load_in_8bit=True,
|
||||
# device_map = "balanced"
|
||||
# device_map="auto",
|
||||
device_map={'':torch.cuda.current_device()},
|
||||
# device_map={'':0}
|
||||
|
||||
)
|
||||
# bnb_config = BitsAndBytesConfig(
|
||||
# load_in_4bit=True,
|
||||
# bnb_4bit_use_double_quant=True,
|
||||
# bnb_4bit_quant_type="nf4",
|
||||
# bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
# )
|
||||
# self.llama_model = llm_model.from_pretrained(
|
||||
# llama_model,
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device_map={'':torch.cuda.current_device()},
|
||||
# quantization_config=bnb_config,
|
||||
# )
|
||||
else:
|
||||
self.llama_model = llm_model.from_pretrained(
|
||||
llama_model,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
|
||||
self.llama_model = prepare_model_for_int8_training(self.llama_model)
|
||||
|
||||
|
||||
|
||||
loraconfig = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
self.llama_model = get_peft_model(self.llama_model, loraconfig)
|
||||
|
||||
# if ckpt_path:
|
||||
# print('load the llm under lora')
|
||||
# ckpt = torch.load(ckpt_path)
|
||||
# set_peft_model_state_dict(self.llama_model,ckpt)
|
||||
|
||||
|
||||
|
||||
self.llama_model.print_trainable_parameters()
|
||||
|
||||
if self.use_grad_checkpoint_llm:
|
||||
self.llama_model.gradient_checkpointing_enable()
|
||||
|
||||
# if not self.low_resource:
|
||||
# for name, param in self.llama_model.named_parameters():
|
||||
# if "embed_token" in name:
|
||||
# param.data = param.data.float()
|
||||
# param.requires_grad = True
|
||||
|
||||
|
||||
print('Loading LLAMA Done')
|
||||
|
||||
|
||||
if self.token_pooling:
|
||||
self.llama_proj = nn.Linear(
|
||||
1408*4, self.llama_model.config.hidden_size
|
||||
)
|
||||
else:
|
||||
self.llama_proj = nn.Linear(
|
||||
1408, self.llama_model.config.hidden_size
|
||||
)
|
||||
|
||||
self.max_txt_len = max_txt_len
|
||||
self.end_sym = end_sym
|
||||
|
||||
if prompt_path:
|
||||
with open(prompt_path, 'r') as f:
|
||||
raw_prompts = f.read().splitlines()
|
||||
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
|
||||
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
|
||||
print('Load {} training prompts'.format(len(self.prompt_list)))
|
||||
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
|
||||
else:
|
||||
self.prompt_list = []
|
||||
|
||||
def encode_img(self, image):
|
||||
device = image.device
|
||||
if len(image.shape) > 4:
|
||||
image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
|
||||
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
|
||||
bs, pn, hs = image_embeds.shape
|
||||
if self.token_pooling: # concat the each 4 tokens into one token (200,64,5632)
|
||||
image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
|
||||
|
||||
inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
|
||||
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||
return inputs_llama, atts_llama
|
||||
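The token pooling in encode_img merges every 4 consecutive vision tokens into one wider token with a plain reshape before projecting into the LLM embedding space. A toy check of the shape arithmetic, with sizes assumed from the comments above (256 patch tokens of width 1408, Llama hidden size 4096):

import torch

bs, pn, hs = 2, 256, 1408
image_embeds = torch.randn(bs, pn, hs)

# Fold every 4 consecutive tokens into one token of 4x the width: (2, 256, 1408) -> (2, 64, 5632).
pooled = image_embeds.view(bs, pn // 4, hs * 4)
print(pooled.shape)  # torch.Size([2, 64, 5632])

# A linear layer then maps the pooled tokens to the LLM hidden size,
# mirroring self.llama_proj = nn.Linear(1408 * 4, hidden_size) above.
proj = torch.nn.Linear(hs * 4, 4096)
print(proj(pooled).shape)  # torch.Size([2, 64, 4096])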
|
||||
def get_context_emb(self, prompt, img_list):
|
||||
img_device = img_list[0].device
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||
seg_tokens = [
|
||||
self.llama_tokenizer(
|
||||
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
|
||||
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||
|
||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
# # truncate the length of tokens to the max context window
|
||||
# mixed_embs_without_instruction = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair]
|
||||
# mixed_embs_without_instruction=torch.cat(mixed_embs_without_instruction, dim=1)
|
||||
# # check if the number of token in the second dimention is more than the max context window then truncate it
|
||||
# context_window=self.max_context_len-seg_embs[-1].shape[1]
|
||||
# if mixed_embs_without_instruction.shape[1] > context_window :
|
||||
# mixed_embs_without_instruction = mixed_embs_without_instruction[:, 0:context_window]
|
||||
# mixed_embs=torch.cat([mixed_embs_without_instruction,seg_embs[-1]], dim=1)
|
||||
# print("mixed_embs",mixed_embs.shape)
|
||||
|
||||
return mixed_embs
|
||||
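get_context_emb alternates the text segments around each '<ImageHere>' placeholder with the corresponding image embeddings before concatenating everything along the sequence dimension. A toy illustration of the zip-and-flatten idiom it uses, with strings standing in for embedding tensors:

prompt = "Describe <ImageHere> and compare it with <ImageHere>."
segs = prompt.split('<ImageHere>')   # 3 text segments for 2 images
imgs = ['IMG_0', 'IMG_1']            # stand-ins for (1, num_tokens, hidden) tensors

mixed = [x for pair in zip(segs[:-1], imgs) for x in pair] + [segs[-1]]
print(mixed)
# ['Describe ', 'IMG_0', ' and compare it with ', 'IMG_1', '.']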
|
||||
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
|
||||
if prompts is None or len(prompts) == 0:
|
||||
# prompts is not provided, just return the original image embedding
|
||||
return img_embeds, atts_img
|
||||
elif img_embeds is None:
|
||||
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
|
||||
self.llama_tokenizer.padding_side = "right"
|
||||
prompt_tokens = self.llama_tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
add_special_tokens=False
|
||||
).to(self.device)
|
||||
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
|
||||
atts_prompt = prompt_tokens.attention_mask
|
||||
return prompt_embeds, atts_prompt
|
||||
|
||||
else:
|
||||
# return the multi-modal embedding in right padding
|
||||
emb_lists = []
|
||||
|
||||
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
|
||||
pn = each_img_embed.shape[-2]
|
||||
if lengths is not None:
|
||||
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
|
||||
each_img_embed = each_img_embed[:lengths[idx] * pn]
|
||||
|
||||
p_segs = each_prompt.split('<ImageHere>')
|
||||
interleave_emb = []
|
||||
for idx, seg in enumerate(p_segs[:-1]):
|
||||
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
# print("p_embed device",p_tokens.input_ids.device)
|
||||
# print("p_tokens",img_embeds.device)
|
||||
# print("emb layer", list(self.llama_model.base_model.model.model.embed_tokens.parameters())[0].device)
|
||||
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||
|
||||
# print("model device",self.llama_model.get_device())
|
||||
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
|
||||
|
||||
wrapped_emb = torch.cat(interleave_emb, dim=1)
|
||||
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
|
||||
emb_lists.append(wrapped_emb)
|
||||
|
||||
emb_lens = [emb.shape[1] for emb in emb_lists]
|
||||
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
|
||||
|
||||
# max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
|
||||
max_length = self.max_context_len
|
||||
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
|
||||
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
|
||||
|
||||
for i, emb in enumerate(emb_lists):
|
||||
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
|
||||
wrapped_embs[i, :length] = emb[:, :length]
|
||||
wrapped_atts[i, :length] = 1
|
||||
|
||||
return wrapped_embs, wrapped_atts
|
||||
|
||||
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
|
||||
"""
|
||||
Concatenate the batched input embedding and batched output embedding together.
|
||||
Both the input and the output embedding should be right padded.
|
||||
"""
|
||||
|
||||
input_lens = []
|
||||
cat_embs = []
|
||||
cat_atts = []
|
||||
|
||||
for i in range(input_embs.size(0)):
|
||||
input_len = input_atts[i].sum()
|
||||
input_lens.append(input_len)
|
||||
|
||||
cat_embs.append(
|
||||
torch.cat([
|
||||
input_embs[i][:input_len],
|
||||
output_embs[i],
|
||||
input_embs[i][input_len:]
|
||||
])
|
||||
)
|
||||
cat_atts.append(
|
||||
torch.cat([
|
||||
input_atts[i][:input_len],
|
||||
output_atts[i],
|
||||
input_atts[i][input_len:]
|
||||
])
|
||||
)
|
||||
# print('===================================')
|
||||
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
|
||||
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
|
||||
# print('check out emb: ', output_embs[i][:2])
|
||||
# print('check out pad emb: ', output_embs[i][-2:])
|
||||
# print('+++++++++++++++++++++++++++++++++++')
|
||||
#
|
||||
# print('check attn before: ', input_atts[i][:this_input_ones])
|
||||
# print('check attn after: ', input_atts[i][this_input_ones:])
|
||||
# print('check attn gt before: ', output_atts[i][:3])
|
||||
# print('check attn gt after: ', output_atts[i][-3:])
|
||||
|
||||
cat_embs = torch.stack(cat_embs)
|
||||
cat_atts = torch.stack(cat_atts)
|
||||
return cat_embs, cat_atts, input_lens
|
||||
|
||||
def get_conv_emb(self, conv_q, conv_a, conv_img):
|
||||
"""concatenate conversation and make sure the model is only trained to regress the answer"""
|
||||
|
||||
regress_embs_list = []
|
||||
targets_list = []
|
||||
|
||||
batch_size = len(conv_q)
|
||||
for batch_idx in range(batch_size):
|
||||
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
|
||||
assigned_imgs = conv_img[batch_idx]
|
||||
questions = [self.prompt_wrap(
|
||||
img_embeds=img,
|
||||
atts_img=None,
|
||||
prompts=[q],
|
||||
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
|
||||
q_embs = [emb for emb, _ in questions]
|
||||
|
||||
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
|
||||
cur_emb = []
|
||||
cur_target = []
|
||||
for i in range(len(questions)):
|
||||
cur_emb.append(q_embs[i])
|
||||
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
|
||||
|
||||
cur_emb.append(self.embed_tokens(answers[i].input_ids))
|
||||
cur_target.append(answers[i].input_ids)
|
||||
|
||||
cur_emb = torch.cat(cur_emb, dim=1)
|
||||
cur_target = torch.cat(cur_target, dim=1)
|
||||
|
||||
regress_embs_list.append(cur_emb)
|
||||
targets_list.append(cur_target)
|
||||
|
||||
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
|
||||
|
||||
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
|
||||
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
|
||||
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
cur_len = regress_embs_list[batch_idx].shape[1]
|
||||
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
|
||||
regress_attn[batch_idx, :cur_len] = 1
|
||||
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
|
||||
|
||||
return regress_embeds, regress_attn, targets
|
||||
|
||||
def preparing_embedding(self, samples):
|
||||
def remove_special_tokens(data):
|
||||
|
||||
# if "instruction_input" in data:
|
||||
data = [instruct.replace(" [caption]","") for instruct in data]
|
||||
data = [instruct.replace(" [vqa]","") for instruct in data]
|
||||
data = [instruct.replace(" [grounding]","") for instruct in data]
|
||||
data = [instruct.replace(" [identify]","") for instruct in data]
|
||||
data = [instruct.replace(" [refer]","") for instruct in data]
|
||||
return data
|
||||
|
||||
### prepare input tokens
|
||||
if 'image' in samples:
|
||||
img_embeds, img_atts = self.encode_img(samples["image"])
|
||||
# print("img_embeds shape",img_embeds.shape)
|
||||
else:
|
||||
img_embeds = img_atts = None
|
||||
|
||||
if 'conv_q' in samples:
|
||||
# handling conversation datasets
|
||||
conv_q, conv_a = samples['conv_q'], samples['conv_a']
|
||||
|
||||
connect_sym = samples['connect_sym'][0]
|
||||
conv_q = [q.split(connect_sym)for q in conv_q]
|
||||
conv_a = [a.split(connect_sym) for a in conv_a]
|
||||
conv_img = assign_imgs(conv_q, img_embeds)
|
||||
|
||||
if self.chat_template:
|
||||
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
|
||||
|
||||
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
|
||||
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
|
||||
|
||||
else:
|
||||
instruction = samples["instruction_input"] if "instruction_input" in samples else None
|
||||
|
||||
# print("instruction before", instruction)
|
||||
if self.remove_template:
|
||||
instruction = remove_special_tokens(instruction)
|
||||
# print("instruction after", instruction)
|
||||
|
||||
if self.chat_template:
|
||||
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
|
||||
|
||||
if 'length' in samples:
|
||||
# the input is an image sequence (like videos)
|
||||
bsz, pn, hs = img_embeds.shape
|
||||
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096)
|
||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
|
||||
else:
|
||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
|
||||
|
||||
### prepare target tokens
|
||||
self.llama_tokenizer.padding_side = "right"
|
||||
text = [t + self.end_sym for t in samples["answer"]]
|
||||
|
||||
regress_tokens = self.llama_tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_txt_len,
|
||||
add_special_tokens=False
|
||||
).to(self.device)
|
||||
|
||||
regress_token_ids = regress_tokens.input_ids
|
||||
regress_atts = regress_tokens.attention_mask
|
||||
part_targets = regress_token_ids.masked_fill(
|
||||
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
|
||||
)
|
||||
|
||||
regress_embeds = self.embed_tokens(regress_token_ids)
|
||||
|
||||
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
|
||||
|
||||
def forward(self, samples, reduction="mean"):
|
||||
# prepare the embedding to condition and the embedding to regress
|
||||
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
|
||||
self.preparing_embedding(samples)
|
||||
|
||||
# concat the embedding to condition and the embedding to regress
|
||||
inputs_embeds, attention_mask, input_lens = \
|
||||
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
|
||||
print("inputs_embeds shape",inputs_embeds.shape)
|
||||
print("cond_embeds shape",cond_embeds.shape)
|
||||
print("regress_embeds shape",regress_embeds.shape)
|
||||
# get bos token embedding
|
||||
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
|
||||
bos_embeds = self.embed_tokens(bos)
|
||||
bos_atts = attention_mask[:, :1]
|
||||
|
||||
# add bos token at the beginning
|
||||
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
|
||||
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
|
||||
# print length of instruction_input and answer words
|
||||
# for i in range (len(samples["instruction_input"])):
|
||||
# print("instruction_input length",len(samples["instruction_input"][i].split(" ")))
|
||||
# print("answer length",len(samples["answer"][i].split(" ")))
|
||||
# ensemble the final targets
|
||||
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
|
||||
dtype=torch.long).to(self.device).fill_(-100)
|
||||
for i, target in enumerate(part_targets):
|
||||
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
|
||||
print("targets shape",targets.shape)
|
||||
with self.maybe_autocast():
|
||||
outputs = self.llama_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
labels=targets,
|
||||
reduction=reduction
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
return {"loss": loss}
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
images,
|
||||
texts,
|
||||
use_nucleus_sampling=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=20,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1,
|
||||
temperature=1,
|
||||
do_sample=False,
|
||||
stop_words_ids=[2],
|
||||
lengths=None,
|
||||
return_video_temporal_features=False,
|
||||
img_embeds=None,
|
||||
):
|
||||
'''
|
||||
function for generate test use
|
||||
'''
|
||||
|
||||
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||
if img_embeds is None:
|
||||
img_embeds, atts_img = self.encode_img(images.to(self.device))
|
||||
else:
|
||||
# Use images features from the input(4,45,64,5632)
|
||||
img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
|
||||
img_embeds= img_embeds.to(self.device)
|
||||
img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
|
||||
atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
||||
|
||||
print("img_embeds shape",img_embeds.shape)
|
||||
if lengths is not None:
|
||||
image_lists = []
|
||||
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
|
||||
for idx, img_embed in enumerate(img_embeds):
|
||||
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
|
||||
else:
|
||||
image_lists = [[image_emb[None]] for image_emb in img_embeds]
|
||||
assert len(texts) == len(image_lists)
|
||||
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
|
||||
|
||||
batch_size = len(batch_embs)
|
||||
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||
emb_dim = batch_embs[0].shape[2]
|
||||
dtype = batch_embs[0].dtype
|
||||
device = batch_embs[0].device
|
||||
|
||||
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||
for i, emb in enumerate(batch_embs):
|
||||
emb_len = emb.shape[1]
|
||||
embs[i, -emb_len:] = emb[0]
|
||||
attn_mask[i, -emb_len:] = 1
|
||||
# print("inputs_embeds shape",embs.shape)
|
||||
# print("attention_mask shape",attn_mask.shape)
|
||||
# if the input embeddings exceed the model context window (4096), truncate them to the most recent tokens
|
||||
if self.model_type == "Llama":
|
||||
context_window = 3700
|
||||
else:
|
||||
context_window = 7500
|
||||
if embs.shape[1] > context_window:
|
||||
embs = embs[:, -context_window:]
|
||||
attn_mask = attn_mask[:, -context_window:]
|
||||
print("inputs_embeds shape",embs.shape)
|
||||
print("attention_mask shape",attn_mask.shape)
|
||||
with self.maybe_autocast():
|
||||
if return_video_temporal_features:
|
||||
last_hidden_state = self.llama_model(
|
||||
inputs_embeds=embs,
|
||||
attention_mask=attn_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-1]
|
||||
video_temporal_features = last_hidden_state.mean(dim=1)
|
||||
# normalize the temporal features using L2 norm
|
||||
# video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True)
|
||||
outputs = self.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
attention_mask=attn_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
num_beams=num_beams,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
# stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
answers = []
|
||||
for output_token in outputs:
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||
output_texts = output_texts.replace("<s>", "")
|
||||
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||
answers.append(output_texts)
|
||||
if return_video_temporal_features:
|
||||
return answers, video_temporal_features
|
||||
else:
|
||||
return answers
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_text_only(
|
||||
self,
|
||||
images,
|
||||
seg_tokens,
|
||||
use_nucleus_sampling=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=20,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1,
|
||||
temperature=1,
|
||||
do_sample=False,
|
||||
stop_words_ids=[2],
|
||||
lengths=None,
|
||||
return_video_temporal_features=False,
|
||||
img_embeds=None,
|
||||
):
|
||||
'''
|
||||
function for generate test use
|
||||
'''
|
||||
|
||||
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||
|
||||
# seg_tokens=[]
|
||||
# for i, text in enumerate(texts):
|
||||
# seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids)
|
||||
|
||||
batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]
|
||||
|
||||
# seg_embs = torch.cat(seg_embs, dim=1)
|
||||
# print("seg_embs shape",seg_embs.shape)
|
||||
# batch_embs=[seg_embs]
|
||||
batch_size = len(batch_embs)
|
||||
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||
emb_dim = batch_embs[0].shape[2]
|
||||
dtype = batch_embs[0].dtype
|
||||
device = batch_embs[0].device
|
||||
|
||||
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||
for i, emb in enumerate(batch_embs):
|
||||
emb_len = emb.shape[1]
|
||||
embs[i, -emb_len:] = emb[0]
|
||||
attn_mask[i, -emb_len:] = 1
|
||||
|
||||
|
||||
print("inputs_embeds shape",embs.shape)
|
||||
print("attention_mask shape",attn_mask.shape)
|
||||
with self.maybe_autocast():
|
||||
outputs = self.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
attention_mask=attn_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
num_beams=num_beams,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
# stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
answers = []
|
||||
for output_token in outputs:
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||
output_texts = output_texts.replace("<s>", "")
|
||||
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||
answers.append(output_texts)
|
||||
return answers
|
||||
|
||||
|
||||
|
||||
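A hypothetical usage sketch for generate_text_only(): the caller tokenizes each text into `seg_tokens` (as in the commented-out block above) and the method embeds, left-pads, and decodes them. `model` and `texts` are placeholders, and `images` is unused by the method body, so it is passed as None here.

```python
# Assumed: `model` is an already-loaded instance of this class on a CUDA device.
texts = ["[INST] Describe the scene. [/INST]", "[INST] What happens next? [/INST]"]
seg_tokens = [
    model.llama_tokenizer(t, return_tensors="pt", add_special_tokens=True).to(model.device).input_ids
    for t in texts
]
answers = model.generate_text_only(images=None, seg_tokens=seg_tokens, max_new_tokens=64)
print(answers)   # one decoded string per input text
```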
@torch.no_grad()
|
||||
def multi_select(self, images, texts, answers, num_cand=None):
|
||||
all_losses = []
|
||||
for answer in answers:
|
||||
choice_samples = {
|
||||
'image': images,
|
||||
'instruction_input': texts,
|
||||
'answer': answer
|
||||
}
|
||||
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
|
||||
all_losses.append(loss)
|
||||
torch.cuda.empty_cache()
|
||||
all_losses = torch.cat(all_losses, dim=-1)
|
||||
if num_cand is not None:
|
||||
for i in range(all_losses.shape[0]):
|
||||
all_losses[i, num_cand[i]:] = 9999
|
||||
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
||||
return output_class_ranks.tolist()
|
||||
|
||||
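The ranking trick in multi_select() can be reproduced standalone: each column of `all_losses` is one candidate answer, rows are batch items, and padding columns beyond the true number of candidates are pushed to the back with a large loss before sorting. Values below are toy numbers.

```python
import torch

all_losses = torch.tensor([[0.7, 0.2, 1.3],
                           [0.9, 1.1, 0.4]])
num_cand = [2, 3]                       # item 0 only has 2 real candidates
for i in range(all_losses.shape[0]):
    all_losses[i, num_cand[i]:] = 9999  # disable padded candidate slots

ranks = torch.argsort(all_losses, dim=-1)
print(ranks.tolist())  # [[1, 0, 2], [2, 0, 1]] -> lowest-loss candidate first
```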
def predict_answers(
|
||||
self,
|
||||
samples,
|
||||
num_beams=5,
|
||||
inference_method="generate",
|
||||
max_len=10,
|
||||
min_len=1,
|
||||
num_ans_candidates=128,
|
||||
answer_list=None,
|
||||
prompt="",
|
||||
length_penalty=0,
|
||||
**kwargs
|
||||
):
|
||||
'''
|
||||
function for open-ended VQA
|
||||
'''
|
||||
images = samples["image"].cuda()
|
||||
texts = samples["instruction_input"]
|
||||
|
||||
output_text = self.generate(
|
||||
images=images,
|
||||
texts=texts,
|
||||
num_beams=num_beams,
|
||||
max_new_tokens=max_len,
|
||||
min_length=min_len,
|
||||
length_penalty=length_penalty
|
||||
)
|
||||
|
||||
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
||||
output_text = self._lemmatize(output_text)
|
||||
|
||||
return output_text
|
||||
|
||||
def predict_class(
|
||||
self,
|
||||
samples,
|
||||
num_beams=5,
|
||||
inference_method="generate",
|
||||
max_len=10,
|
||||
min_len=1,
|
||||
num_ans_candidates=5,
|
||||
answer_list=None,
|
||||
prompt="",
|
||||
length_penalty=0,
|
||||
**kwargs
|
||||
):
|
||||
'''
|
||||
function for multi-choice VQA
|
||||
'''
|
||||
|
||||
image = samples["image"].cuda()
|
||||
instruction = samples['instruction_input']
|
||||
answers = samples["choices"]
|
||||
num_cand = samples["num_choices"]
|
||||
|
||||
ranks = self.multi_select(image, instruction, answers, num_cand)
|
||||
|
||||
pred_ans = []
|
||||
for i, rank in enumerate(ranks):
|
||||
pred = answers[rank[0]][i]
|
||||
pred_ans.append(pred)
|
||||
return pred_ans
|
||||
|
||||
def embed_tokens(self, token_ids):
|
||||
try:
|
||||
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
|
||||
except AttributeError:
|
||||
embeds = self.llama_model.model.embed_tokens(token_ids)
|
||||
|
||||
return embeds
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
|
||||
img_size = cfg.get("image_size")
|
||||
num_query_token = cfg.get("num_query_token")
|
||||
llama_model = cfg.get("llama_model")
|
||||
|
||||
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||
vit_precision = cfg.get("vit_precision", "fp16")
|
||||
freeze_vit = cfg.get("freeze_vit", True)
|
||||
freeze_qformer = cfg.get("freeze_qformer", True)
|
||||
low_resource = cfg.get("low_resource", False)
|
||||
|
||||
prompt_path = cfg.get("prompt_path", "")
|
||||
prompt_template = cfg.get("prompt_template", "")
|
||||
max_txt_len = cfg.get("max_txt_len", 300)
|
||||
end_sym = cfg.get("end_sym", '\n')
|
||||
|
||||
lora_r = cfg.get("lora_r",64)
|
||||
lora_alpha = cfg.get("lora_alpha",16)
|
||||
chat_template = cfg.get("chat_template",False)
|
||||
system_prompt = cfg.get("system_prompt", False)
|
||||
token_pooling = cfg.get("token_pooling",True)
|
||||
|
||||
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
|
||||
max_context_len = cfg.get("max_context_len", 3800)
|
||||
remove_template = cfg.get("remove_template", False)
|
||||
|
||||
|
||||
model = cls(
|
||||
vit_model=vit_model,
|
||||
img_size=img_size,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_grad_checkpoint=use_grad_checkpoint,
|
||||
vit_precision=vit_precision,
|
||||
freeze_vit=freeze_vit,
|
||||
llama_model=llama_model,
|
||||
prompt_path=prompt_path,
|
||||
prompt_template=prompt_template,
|
||||
max_txt_len=max_txt_len,
|
||||
low_resource=low_resource,
|
||||
end_sym=end_sym,
|
||||
lora_r = lora_r,
|
||||
lora_alpha = lora_alpha,
|
||||
chat_template = chat_template,
|
||||
system_prompt = system_prompt,
|
||||
token_pooling = token_pooling,
|
||||
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
|
||||
max_context_len=max_context_len,
|
||||
remove_template = remove_template
|
||||
)
|
||||
|
||||
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||
if ckpt_path:
|
||||
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def assign_imgs(batched_instruct_list, batched_img_embeds):
|
||||
'''this function is used when the data is interleaved.
|
||||
the interleaved data is separated, and this function assigns
|
||||
corresponding image embeddings to each segment'''
|
||||
if len(batched_img_embeds.shape) == 3:
|
||||
batched_img_embeds = batched_img_embeds[:, None]
|
||||
|
||||
batched_assigned = []
|
||||
|
||||
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
|
||||
img_idx = 0
|
||||
assigned_img = []
|
||||
n_assigned = []
|
||||
for instruct in instruct_list:
|
||||
n_img = instruct.count('<ImageHere>')
|
||||
if n_img > 0: # this instruction includes images.
|
||||
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
|
||||
img_idx += n_img
|
||||
n_assigned.append(n_img)
|
||||
else: # this instruction doesn't include images
|
||||
assigned_img.append(None)
|
||||
n_assigned.append(None)
|
||||
batched_assigned.append(assigned_img)
|
||||
|
||||
return batched_assigned
|
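A hypothetical illustration of assign_imgs() (assuming it is importable from this module): two interleaved instruction turns, the first referencing one image and the second purely textual. Shapes are toy values.

```python
import torch

batched_instruct_list = [["<ImageHere> What is shown here?", "And in general?"]]
batched_img_embeds = torch.randn(1, 1, 4, 8)   # (batch, n_imgs, tokens, dim)

assigned = assign_imgs(batched_instruct_list, batched_img_embeds)
print(assigned[0][0].shape)  # torch.Size([1, 1, 4, 8]) -> image assigned to turn 0
print(assigned[0][1])        # None -> turn 1 has no image placeholder
```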
709
models/backbones/mini_gpt4v.py
Executable file
709
models/backbones/mini_gpt4v.py
Executable file
|
@ -0,0 +1,709 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast as autocast
|
||||
import torch.nn as nn
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.blip2 import Blip2Base, disabled_train
|
||||
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
|
||||
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
|
||||
|
||||
from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig
|
||||
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
prepare_model_for_kbit_training
|
||||
)
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from minigpt4.models import policies
|
||||
|
||||
|
||||
@registry.register_model("mini_gpt4v")
|
||||
class MiniGPT4v(Blip2Base):
|
||||
"""
|
||||
BLIP2 GPT-LLAMA model.
|
||||
"""
|
||||
|
||||
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||
"pretrain_vicuna": "configs/models/minigpt4.yaml",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model="eva_clip_g",
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision="fp16",
|
||||
freeze_vit=True,
|
||||
llama_model="",
|
||||
prompt_path="",
|
||||
prompt_template="",
|
||||
max_txt_len=32,
|
||||
low_resource=False, # use 8 bit and put vit in cpu
|
||||
end_sym='\n',
|
||||
lora_r = 8,
|
||||
lora_target_modules = ["q_proj","v_proj"],
|
||||
lora_alpha=16,
|
||||
# lora_r = 16,
|
||||
# lora_target_modules = ["q_proj","v_proj","v_proj"],
|
||||
lora_dropout= 0.05,
|
||||
ckpt_path = "",
|
||||
system_prompt= False,
|
||||
chat_template=False,
|
||||
token_pooling=True,
|
||||
use_grad_checkpoint_llm=False,
|
||||
max_context_len=3800,
|
||||
remove_template = False,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = self.init_tokenizer()
|
||||
self.low_resource = low_resource
|
||||
self.token_pooling = token_pooling
|
||||
self.remove_template = remove_template
|
||||
|
||||
print("token pooling", self.token_pooling)
|
||||
|
||||
|
||||
self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
|
||||
self.max_context_len = max_context_len
|
||||
self.chat_template = chat_template
|
||||
|
||||
# print('Loading VIT')
|
||||
# self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
# )
|
||||
|
||||
|
||||
print("vit precision", vit_precision)
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
)
|
||||
for name, param in self.visual_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.visual_encoder = self.visual_encoder.eval()
|
||||
self.visual_encoder.train = disabled_train
|
||||
for name, param in self.ln_vision.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.ln_vision = self.ln_vision.eval()
|
||||
self.ln_vision.train = disabled_train
|
||||
logging.info("freeze vision encoder")
|
||||
print("freeze the vision encoder")
|
||||
|
||||
|
||||
print('Loading VIT Done')
|
||||
|
||||
# print("visual encoder shape", self.visual_encoder.pos_embed.shape)
|
||||
# assert False
|
||||
|
||||
print('Loading LLAMA')
|
||||
|
||||
|
||||
self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
if 'CodeLlama' in llama_model:
|
||||
self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
|
||||
self.llama_tokenizer.pad_token = "$$"
|
||||
else:
|
||||
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
|
||||
self.llama_tokenizer.pad_token = "$$"
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
|
||||
|
||||
self.llama_model = LlamaForCausalLM.from_pretrained(
|
||||
llama_model,
|
||||
quantization_config=bnb_config,
|
||||
device_map={"": 0}
|
||||
)
|
||||
|
||||
# self.llama_model.gradient_checkpointing_enable()
|
||||
self.llama_model = prepare_model_for_kbit_training(self.llama_model)
|
||||
|
||||
# self.llama_model.print_trainable_parameters()
|
||||
|
||||
|
||||
print('Loading LLAMA Done')
|
||||
|
||||
self.merge_n = 3
|
||||
|
||||
self.llama_proj = nn.Linear(
|
||||
1408 * self.merge_n**2, self.llama_model.config.hidden_size
|
||||
)
|
||||
|
||||
self.max_txt_len = max_txt_len
|
||||
self.end_sym = end_sym
|
||||
|
||||
if prompt_path:
|
||||
with open(prompt_path, 'r') as f:
|
||||
raw_prompts = f.read().splitlines()
|
||||
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
|
||||
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
|
||||
print('Load {} training prompts'.format(len(self.prompt_list)))
|
||||
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
|
||||
else:
|
||||
self.prompt_list = []
|
||||
|
||||
def encode_img(self, image):
|
||||
device = image.device
|
||||
if len(image.shape) > 4:
|
||||
image = image.reshape(-1, *image.shape[-3:])
|
||||
|
||||
bs, ch, w, h = image.shape
|
||||
assert w % 224 == 0
|
||||
bw = w // 224
|
||||
assert h % 224 == 0
|
||||
bh = h // 224
|
||||
image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224
|
||||
image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)
|
||||
|
||||
with self.maybe_autocast():
|
||||
image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
|
||||
|
||||
image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
|
||||
image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs
|
||||
image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])
|
||||
|
||||
bs, pn, hs = image_embeds.shape
|
||||
|
||||
image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2))
|
||||
|
||||
inputs_llama = self.llama_proj(image_embeds)
|
||||
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||
return inputs_llama, atts_llama
|
||||
|
||||
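The shape bookkeeping in encode_img() can be checked without the ViT. The sketch below assumes a 672x672 input purely so that the 224-crop grid is 3x3 and the resulting token count is divisible by merge_n**2 = 9; the actual training resolution may differ.

```python
import torch

bs, ch, w, h = 2, 3, 672, 672
bw, bh = w // 224, h // 224                      # 3 x 3 crops per image
image = torch.randn(bs, ch, w, h)

patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5)
patches = patches.reshape(bs * bw * bh, ch, 224, 224)
print(patches.shape)                             # (18, 3, 224, 224) -> 9 crops per image

# The ViT yields 16x16 = 256 patch tokens per 224 crop (CLS token dropped),
# so the flattened visual sequence is bw*16 * bh*16 tokens of width 1408.
vis_tokens = torch.randn(bs, bw * 16 * bh * 16, 1408)
merge_n = 3
merged = vis_tokens.view(bs, vis_tokens.shape[1] // merge_n**2, 1408 * merge_n**2)
print(merged.shape)                              # (2, 256, 12672) -> fed to llama_proj
```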
def get_context_emb(self, prompt, img_list):
|
||||
img_device = img_list[0].device
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||
seg_tokens = [
|
||||
self.llama_tokenizer(
|
||||
seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
|
||||
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||
|
||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
return mixed_embs
|
||||
|
||||
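A model-free sketch of the interleaving done by get_context_emb(): text segments around each '<ImageHere>' placeholder are embedded separately, and the image embeddings are zipped back in between them. The tensors below are toy stand-ins for the real embeddings.

```python
import torch

prompt = "[INST] <ImageHere> Describe the clip. [/INST]"
prompt_segs = prompt.split('<ImageHere>')          # ['[INST] ', ' Describe the clip. [/INST]']

seg_embs = [torch.randn(1, 4, 8), torch.randn(1, 6, 8)]   # pretend token embeddings
img_list = [torch.randn(1, 5, 8)]                          # one image, 5 visual tokens

mixed = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed = torch.cat(mixed, dim=1)
print(mixed.shape)   # torch.Size([1, 15, 8]) -> text / image / text in order
```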
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
|
||||
if prompts is None or len(prompts) == 0:
|
||||
# prompts is not provided, just return the original image embedding
|
||||
return img_embeds, atts_img
|
||||
elif img_embeds is None:
|
||||
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
|
||||
self.llama_tokenizer.padding_side = "right"
|
||||
prompt_tokens = self.llama_tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
add_special_tokens=False
|
||||
).to(self.device)
|
||||
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
|
||||
atts_prompt = prompt_tokens.attention_mask
|
||||
return prompt_embeds, atts_prompt
|
||||
|
||||
else:
|
||||
# return the multi-modal embedding in right padding
|
||||
emb_lists = []
|
||||
|
||||
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
|
||||
pn = each_img_embed.shape[-2]
|
||||
if lengths is not None:
|
||||
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
|
||||
each_img_embed = each_img_embed[:lengths[idx] * pn]
|
||||
|
||||
p_segs = each_prompt.split('<ImageHere>')
|
||||
interleave_emb = []
|
||||
for idx, seg in enumerate(p_segs[:-1]):
|
||||
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
|
||||
|
||||
wrapped_emb = torch.cat(interleave_emb, dim=1)
|
||||
p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
p_embed = self.embed_tokens(p_tokens.input_ids)
|
||||
wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
|
||||
emb_lists.append(wrapped_emb)
|
||||
|
||||
emb_lens = [emb.shape[1] for emb in emb_lists]
|
||||
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
|
||||
|
||||
max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
|
||||
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
|
||||
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
|
||||
|
||||
for i, emb in enumerate(emb_lists):
|
||||
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
|
||||
wrapped_embs[i, :length] = emb[:, :length]
|
||||
wrapped_atts[i, :length] = 1
|
||||
|
||||
return wrapped_embs, wrapped_atts
|
||||
|
||||
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
|
||||
"""
|
||||
Concatenate the batched input embedding and batched output embedding together.
|
||||
Both the input and the output embedding should be right padded.
|
||||
"""
|
||||
|
||||
input_lens = []
|
||||
cat_embs = []
|
||||
cat_atts = []
|
||||
|
||||
for i in range(input_embs.size(0)):
|
||||
input_len = input_atts[i].sum()
|
||||
input_lens.append(input_len)
|
||||
|
||||
cat_embs.append(
|
||||
torch.cat([
|
||||
input_embs[i][:input_len],
|
||||
output_embs[i],
|
||||
input_embs[i][input_len:]
|
||||
])
|
||||
)
|
||||
cat_atts.append(
|
||||
torch.cat([
|
||||
input_atts[i][:input_len],
|
||||
output_atts[i],
|
||||
input_atts[i][input_len:]
|
||||
])
|
||||
)
|
||||
# print('===================================')
|
||||
# print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
|
||||
# print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
|
||||
# print('check out emb: ', output_embs[i][:2])
|
||||
# print('check out pad emb: ', output_embs[i][-2:])
|
||||
# print('+++++++++++++++++++++++++++++++++++')
|
||||
#
|
||||
# print('check attn before: ', input_atts[i][:this_input_ones])
|
||||
# print('check attn after: ', input_atts[i][this_input_ones:])
|
||||
# print('check attn gt before: ', output_atts[i][:3])
|
||||
# print('check attn gt after: ', output_atts[i][-3:])
|
||||
|
||||
cat_embs = torch.stack(cat_embs)
|
||||
cat_atts = torch.stack(cat_atts)
|
||||
return cat_embs, cat_atts, input_lens
|
||||
|
||||
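A minimal sketch of concat_emb_input_output(): for each row the answer embedding is spliced in right after the last real prompt token, and the prompt's padding is pushed to the end, so the result stays right-padded. Shapes are toy values.

```python
import torch

input_embs = torch.arange(12.).reshape(2, 3, 2)          # (batch, prompt_len, dim)
input_atts = torch.tensor([[1, 1, 0], [1, 1, 1]])        # row 0 has one pad slot
output_embs = torch.full((2, 2, 2), 9.)                  # (batch, answer_len, dim)
output_atts = torch.ones(2, 2, dtype=torch.long)

cat_embs, cat_atts, input_lens = [], [], []
for i in range(input_embs.size(0)):
    n = input_atts[i].sum()
    input_lens.append(int(n))
    cat_embs.append(torch.cat([input_embs[i][:n], output_embs[i], input_embs[i][n:]]))
    cat_atts.append(torch.cat([input_atts[i][:n], output_atts[i], input_atts[i][n:]]))

print(torch.stack(cat_atts))   # tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
print(input_lens)              # [2, 3] -> where each answer starts
```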
def get_conv_emb(self, conv_q, conv_a, conv_img):
|
||||
"""concatenate conversation and make sure the model is only trained to regress the answer"""
|
||||
|
||||
regress_embs_list = []
|
||||
targets_list = []
|
||||
|
||||
batch_size = len(conv_q)
|
||||
for batch_idx in range(batch_size):
|
||||
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
|
||||
assigned_imgs = conv_img[batch_idx]
|
||||
questions = [self.prompt_wrap(
|
||||
img_embeds=img,
|
||||
atts_img=None,
|
||||
prompts=[q],
|
||||
lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
|
||||
q_embs = [emb for emb, _ in questions]
|
||||
|
||||
answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
|
||||
cur_emb = []
|
||||
cur_target = []
|
||||
for i in range(len(questions)):
|
||||
cur_emb.append(q_embs[i])
|
||||
cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
|
||||
|
||||
cur_emb.append(self.embed_tokens(answers[i].input_ids))
|
||||
cur_target.append(answers[i].input_ids)
|
||||
|
||||
cur_emb = torch.cat(cur_emb, dim=1)
|
||||
cur_target = torch.cat(cur_target, dim=1)
|
||||
|
||||
regress_embs_list.append(cur_emb)
|
||||
targets_list.append(cur_target)
|
||||
|
||||
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
|
||||
|
||||
regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
|
||||
regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
|
||||
targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
cur_len = regress_embs_list[batch_idx].shape[1]
|
||||
regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
|
||||
regress_attn[batch_idx, :cur_len] = 1
|
||||
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
|
||||
|
||||
return regress_embeds, regress_attn, targets
|
||||
|
||||
def preparing_embedding(self, samples):
|
||||
def remove_special_tokens(data):
|
||||
|
||||
# if "instruction_input" in data:
|
||||
data = [instruct.replace(" [caption]","") for instruct in data]
|
||||
data = [instruct.replace(" [vqa]","") for instruct in data]
|
||||
data = [instruct.replace(" [grounding]","") for instruct in data]
|
||||
data = [instruct.replace(" [identify]","") for instruct in data]
|
||||
data = [instruct.replace(" [refer]","") for instruct in data]
|
||||
return data
|
||||
|
||||
### prepare input tokens
|
||||
if 'image' in samples:
|
||||
img_embeds, img_atts = self.encode_img(samples["image"])
|
||||
else:
|
||||
img_embeds = img_atts = None
|
||||
|
||||
if 'conv_q' in samples:
|
||||
# handling conversation datasets
|
||||
conv_q, conv_a = samples['conv_q'], samples['conv_a']
|
||||
|
||||
connect_sym = samples['connect_sym'][0]
|
||||
conv_q = [q.split(connect_sym) for q in conv_q]
|
||||
conv_a = [a.split(connect_sym) for a in conv_a]
|
||||
conv_img = assign_imgs(conv_q, img_embeds)
|
||||
|
||||
if self.chat_template:
|
||||
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
|
||||
|
||||
regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
|
||||
cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
|
||||
|
||||
else:
|
||||
instruction = samples["instruction_input"] if "instruction_input" in samples else None
|
||||
|
||||
# print("instruction before", instruction)
|
||||
if self.remove_template:
|
||||
instruction = remove_special_tokens(instruction)
|
||||
# print("instruction after", instruction)
|
||||
|
||||
if self.chat_template:
|
||||
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
|
||||
|
||||
if 'length' in samples:
|
||||
# the input is a sequence of images (e.g. video frames)
|
||||
bsz, pn, hs = img_embeds.shape
|
||||
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
|
||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
|
||||
else:
|
||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
|
||||
|
||||
### prepare target tokens
|
||||
self.llama_tokenizer.padding_side = "right"
|
||||
text = [t + self.end_sym for t in samples["answer"]]
|
||||
|
||||
regress_tokens = self.llama_tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
truncation=True,
|
||||
max_length=self.max_txt_len,
|
||||
add_special_tokens=False
|
||||
).to(self.device)
|
||||
|
||||
regress_token_ids = regress_tokens.input_ids
|
||||
regress_atts = regress_tokens.attention_mask
|
||||
part_targets = regress_token_ids.masked_fill(
|
||||
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
|
||||
)
|
||||
|
||||
regress_embeds = self.embed_tokens(regress_token_ids)
|
||||
|
||||
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
|
||||
|
||||
def forward(self, samples, reduction="mean"):
|
||||
# prepare the embedding to condition and the embedding to regress
|
||||
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
|
||||
self.preparing_embedding(samples)
|
||||
|
||||
# concat the embedding to condition and the embedding to regress
|
||||
inputs_embeds, attention_mask, input_lens = \
|
||||
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
|
||||
|
||||
# get bos token embedding
|
||||
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
|
||||
bos_embeds = self.embed_tokens(bos)
|
||||
bos_atts = attention_mask[:, :1]
|
||||
|
||||
# add bos token at the beginning
|
||||
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
|
||||
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
|
||||
|
||||
# ensemble the final targets
|
||||
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
|
||||
dtype=torch.long).to(self.device).fill_(-100)
|
||||
for i, target in enumerate(part_targets):
|
||||
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
|
||||
|
||||
with self.maybe_autocast():
|
||||
outputs = self.llama_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
labels=targets,
|
||||
reduction=reduction
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
return {"loss": loss}
|
||||
|
||||
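The label layout built in forward() can be illustrated in isolation: the loss is only computed on the answer ("regress") tokens, so everything else, including the prepended BOS and the conditioning prompt, is filled with -100 and ignored by the LM loss. Lengths below are toy values.

```python
import torch

seq_len, cond_len = 9, 4                        # 1 BOS + 4 conditioning tokens + 4 remaining slots
part_target = torch.tensor([5, 6, 7])           # answer token ids for one sample

targets = torch.full((1, seq_len), -100, dtype=torch.long)
targets[0, cond_len + 1:cond_len + 1 + len(part_target)] = part_target   # +1 for BOS

print(targets)   # tensor([[-100, -100, -100, -100, -100,    5,    6,    7, -100]])
```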
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
images,
|
||||
texts,
|
||||
use_nucleus_sampling=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=20,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1,
|
||||
length_penalty=1,
|
||||
temperature=1,
|
||||
do_sample=False,
|
||||
stop_words_ids=[2],
|
||||
lengths=None,
|
||||
):
|
||||
'''
|
||||
function for generation at test time
|
||||
'''
|
||||
|
||||
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
|
||||
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
|
||||
|
||||
img_embeds, atts_img = self.encode_img(images.to(self.device))
|
||||
if lengths is not None:
|
||||
image_lists = []
|
||||
img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
|
||||
for idx, img_embed in enumerate(img_embeds):
|
||||
image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
|
||||
else:
|
||||
image_lists = [[image_emb[None]] for image_emb in img_embeds]
|
||||
assert len(texts) == len(image_lists)
|
||||
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
|
||||
|
||||
batch_size = len(batch_embs)
|
||||
max_len = max([emb.shape[1] for emb in batch_embs])
|
||||
emb_dim = batch_embs[0].shape[2]
|
||||
dtype = batch_embs[0].dtype
|
||||
device = batch_embs[0].device
|
||||
|
||||
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
|
||||
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
|
||||
for i, emb in enumerate(batch_embs):
|
||||
emb_len = emb.shape[1]
|
||||
embs[i, -emb_len:] = emb[0]
|
||||
attn_mask[i, -emb_len:] = 1
|
||||
|
||||
with self.maybe_autocast():
|
||||
outputs = self.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
attention_mask=attn_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
num_beams=num_beams,
|
||||
do_sample=do_sample,
|
||||
# stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
answers = []
|
||||
for output_token in outputs:
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
|
||||
output_texts = output_texts.replace("<s>", "")
|
||||
output_texts = output_texts.split(r'[/INST]')[-1].strip()
|
||||
answers.append(output_texts)
|
||||
|
||||
return answers
|
||||
|
||||
@torch.no_grad()
|
||||
def multi_select(self, images, texts, answers, num_cand=None):
|
||||
all_losses = []
|
||||
for answer in answers:
|
||||
choice_samples = {
|
||||
'image': images,
|
||||
'instruction_input': texts,
|
||||
'answer': answer
|
||||
}
|
||||
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
|
||||
all_losses.append(loss)
|
||||
torch.cuda.empty_cache()
|
||||
all_losses = torch.cat(all_losses, dim=-1)
|
||||
if num_cand is not None:
|
||||
for i in range(all_losses.shape[0]):
|
||||
all_losses[i, num_cand[i]:] = 9999
|
||||
output_class_ranks = torch.argsort(all_losses, dim=-1)
|
||||
return output_class_ranks.tolist()
|
||||
|
||||
def predict_answers(
|
||||
self,
|
||||
samples,
|
||||
num_beams=5,
|
||||
inference_method="generate",
|
||||
max_len=10,
|
||||
min_len=1,
|
||||
num_ans_candidates=128,
|
||||
answer_list=None,
|
||||
prompt="",
|
||||
length_penalty=0,
|
||||
**kwargs
|
||||
):
|
||||
'''
|
||||
function for open-ended VQA
|
||||
'''
|
||||
images = samples["image"].cuda()
|
||||
texts = samples["instruction_input"]
|
||||
|
||||
output_text = self.generate(
|
||||
images=images,
|
||||
texts=texts,
|
||||
num_beams=num_beams,
|
||||
max_new_tokens=max_len,
|
||||
min_length=min_len,
|
||||
length_penalty=length_penalty
|
||||
)
|
||||
|
||||
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
|
||||
output_text = self._lemmatize(output_text)
|
||||
|
||||
return output_text
|
||||
|
||||
def predict_class(
|
||||
self,
|
||||
samples,
|
||||
num_beams=5,
|
||||
inference_method="generate",
|
||||
max_len=10,
|
||||
min_len=1,
|
||||
num_ans_candidates=5,
|
||||
answer_list=None,
|
||||
prompt="",
|
||||
length_penalty=0,
|
||||
**kwargs
|
||||
):
|
||||
'''
|
||||
function for multi-choice VQA
|
||||
'''
|
||||
|
||||
image = samples["image"].cuda()
|
||||
instruction = samples['instruction_input']
|
||||
answers = samples["choices"]
|
||||
num_cand = samples["num_choices"]
|
||||
|
||||
ranks = self.multi_select(image, instruction, answers, num_cand)
|
||||
|
||||
pred_ans = []
|
||||
for i, rank in enumerate(ranks):
|
||||
pred = answers[rank[0]][i]
|
||||
pred_ans.append(pred)
|
||||
return pred_ans
|
||||
|
||||
def embed_tokens(self, token_ids):
|
||||
try:
|
||||
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
|
||||
except AttributeError:
|
||||
embeds = self.llama_model.model.embed_tokens(token_ids)
|
||||
|
||||
return embeds
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
|
||||
img_size = cfg.get("image_size")
|
||||
num_query_token = cfg.get("num_query_token")
|
||||
llama_model = cfg.get("llama_model")
|
||||
|
||||
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||
vit_precision = cfg.get("vit_precision", "fp16")
|
||||
freeze_vit = cfg.get("freeze_vit", True)
|
||||
freeze_qformer = cfg.get("freeze_qformer", True)
|
||||
low_resource = cfg.get("low_resource", False)
|
||||
|
||||
prompt_path = cfg.get("prompt_path", "")
|
||||
prompt_template = cfg.get("prompt_template", "")
|
||||
max_txt_len = cfg.get("max_txt_len", 300)
|
||||
end_sym = cfg.get("end_sym", '\n')
|
||||
|
||||
lora_r = cfg.get("lora_r",64)
|
||||
lora_alpha = cfg.get("lora_alpha",16)
|
||||
chat_template = cfg.get("chat_template",False)
|
||||
system_prompt = cfg.get("system_prompt", False)
|
||||
token_pooling = cfg.get("token_pooling",True)
|
||||
|
||||
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
|
||||
max_context_len = cfg.get("max_context_len", 3800)
|
||||
remove_template = cfg.get("remove_template", False)
|
||||
|
||||
|
||||
model = cls(
|
||||
vit_model=vit_model,
|
||||
img_size=img_size,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_grad_checkpoint=use_grad_checkpoint,
|
||||
vit_precision=vit_precision,
|
||||
freeze_vit=freeze_vit,
|
||||
llama_model=llama_model,
|
||||
prompt_path=prompt_path,
|
||||
prompt_template=prompt_template,
|
||||
max_txt_len=max_txt_len,
|
||||
low_resource=low_resource,
|
||||
end_sym=end_sym,
|
||||
lora_r = lora_r,
|
||||
lora_alpha = lora_alpha,
|
||||
chat_template = chat_template,
|
||||
system_prompt = system_prompt,
|
||||
token_pooling = token_pooling,
|
||||
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
|
||||
max_context_len=max_context_len,
|
||||
remove_template = remove_template
|
||||
)
|
||||
|
||||
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||
if ckpt_path:
|
||||
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def assign_imgs(batched_instruct_list, batched_img_embeds):
|
||||
'''this function is used when the data is interleaved.
|
||||
the interleaved data is separated, and this function assigns
|
||||
corresponding image embeddings to each segment'''
|
||||
if len(batched_img_embeds.shape) == 3:
|
||||
batched_img_embeds = batched_img_embeds[:, None]
|
||||
|
||||
batched_assigned = []
|
||||
|
||||
for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
|
||||
img_idx = 0
|
||||
assigned_img = []
|
||||
n_assigned = []
|
||||
for instruct in instruct_list:
|
||||
n_img = instruct.count('<ImageHere>')
|
||||
if n_img > 0: # this instruction includes images.
|
||||
assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
|
||||
img_idx += n_img
|
||||
n_assigned.append(n_img)
|
||||
else: # this instruction doesn't include images
|
||||
assigned_img.append(None)
|
||||
n_assigned.append(None)
|
||||
batched_assigned.append(assigned_img)
|
||||
|
||||
return batched_assigned
|
25
models/backbones/mistral.py
Normal file
25
models/backbones/mistral.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
device = "cuda" # the device to load the model onto
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is your favourite condiment?"},
|
||||
{"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
|
||||
{"role": "user", "content": "Do you have mayonnaise recipes?"}
|
||||
]
|
||||
p="Well, I'm quite partial to a good squeeze of fresh lemon juice."
|
||||
encoded_input = tokenizer(p, return_tensors='pt')
|
||||
embeds = model.model.embed_tokens(encoded_input.input_ids)
|
||||
print(embeds.shape)
|
||||
|
||||
|
||||
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
|
||||
model_inputs = encodeds.to(device)
|
||||
model.to(device)
|
||||
|
||||
generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
|
||||
decoded = tokenizer.batch_decode(generated_ids)
|
||||
print(decoded[0])
|
112
models/backbones/modeling_llama_v2.py
Normal file
112
models/backbones/modeling_llama_v2.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
reduction: Optional[str] = "mean",
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction=reduction)
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
if reduction == "none":
|
||||
loss = loss.view(logits.size(0), -1).mean(1)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
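The only change in this subclass is threading a `reduction` argument through to the loss. The sketch below shows what `reduction="none"` enables: the flattened token losses are reshaped back to (batch, seq) and averaged per sample, which is what multi_select() relies on to score each candidate answer separately. It uses plain tensors rather than the model itself.

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))
labels[0, :2] = -100                             # ignored positions, as in the -100 targets above

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
per_sample = loss.view(batch, -1).mean(1)        # one scalar loss per batch item
print(per_sample.shape)                          # torch.Size([2])
```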
112
models/backbones/modeling_llama_v3.py
Normal file
112
models/backbones/modeling_llama_v3.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
reduction: Optional[str] = "mean",
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction=reduction)
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
if reduction == "none":
|
||||
loss = loss.view(logits.size(0), -1).mean(1)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
1388
models/backbones/modeling_mistral.py
Normal file
1388
models/backbones/modeling_mistral.py
Normal file
File diff suppressed because it is too large
Load diff
287
models/backbones/moes.py
Normal file
287
models/backbones/moes.py
Normal file
|
@ -0,0 +1,287 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from timm.models.layers import DropPath
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
B, N, C = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
if mask is not None:
|
||||
# if mask.dim() != x.dim():
|
||||
# expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
|
||||
# else:
|
||||
# expanded_mask = mask
|
||||
mask = mask.bool()
|
||||
attn = attn.masked_fill(~mask, float("-inf"))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
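A standalone sketch of the masking convention used by Attention.forward() above: the caller is expected to pass a mask that broadcasts against the (B, heads, N, N) score tensor; it is converted to bool and masked-out positions are set to -inf before the softmax. Here one is built from a simple padding vector (toy sizes).

```python
import torch

B, heads, N = 1, 2, 4
attn = torch.randn(B, heads, N, N)
mask = torch.tensor([[1, 1, 1, 0]])               # last token is padding
mask = mask[:, None, None, :].expand(B, 1, N, N)  # broadcastable key mask

attn = attn.masked_fill(~mask.bool(), float("-inf"))
attn = attn.softmax(dim=-1)
print(attn[0, 0, 0])                              # last column is exactly 0 after softmax
```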
class MoELayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
expert_type,
|
||||
use_sep_spatial_temp_experts=True,
|
||||
has_hist=False,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=LlamaRMSNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_hist = has_hist
|
||||
self.use_sep_spatial_temp_experts = use_sep_spatial_temp_experts
|
||||
self.norm_att = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
if expert_type == 'modalities':
|
||||
# EXPERT CONSTRUCTION
|
||||
if use_sep_spatial_temp_experts:
|
||||
# Spatial expert
|
||||
self.norm_spatial = norm_layer(dim)
|
||||
self.mlp_spatial = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Temporal expert
|
||||
self.norm_temp = norm_layer(dim)
|
||||
self.mlp_temp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Vis expert
|
||||
self.norm_vis = norm_layer(dim)
|
||||
self.mlp_vis = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# caption expert
|
||||
self.norm_cap = norm_layer(dim)
|
||||
self.mlp_cap = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# history expert
|
||||
if has_hist:
|
||||
self.norm_hist = norm_layer(dim)
|
||||
self.mlp_hist = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
elif expert_type == 'fusion':
|
||||
# Fusion expert
|
||||
self.norm_fusion = norm_layer(dim)
|
||||
self.mlp_fusion = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def forward(self, x, vis_feat_len, cap_feat_len, expert_flag, hist_feat_len=None, is_vid=False, mask=None, only_text=False, expert_permutation=None):
|
||||
|
||||
if self.has_hist:
|
||||
assert hist_feat_len is not None
|
||||
|
||||
x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
|
||||
x = x + self.drop_path(x_shortcut)
|
||||
len_init = x.size(1)
|
||||
# bs, h_dim = x.size(0), x.size(-1)
|
||||
# device = x.device
|
||||
# if only_text:
|
||||
# # end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
|
||||
# # x = x[:, special_toks_indices['<caption>']: end_idx_caption, :]
|
||||
# x = x + self.drop_path(self.mlp_cap(self.norm_cap(x)))
|
||||
|
||||
if expert_flag == 'modalities':
|
||||
if self.use_sep_spatial_temp_experts:
|
||||
x_spatial = x[:, :vis_feat_len]
|
||||
if expert_permutation is not None:
|
||||
if expert_permutation['spatial'] == 'temporal':
|
||||
x_spatial = x_spatial + self.drop_path(self.mlp_temp(self.norm_temp(x_spatial)))
|
||||
elif expert_permutation['spatial'] == 'caption':
|
||||
x_spatial = x_spatial + self.drop_path(self.mlp_cap(self.norm_cap(x_spatial)))
|
||||
elif expert_permutation['spatial'] == 'history':
|
||||
x_spatial = x_spatial + self.drop_path(self.mlp_hist(self.norm_hist(x_spatial)))
|
||||
elif expert_permutation['spatial'] == 'spatial':
|
||||
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
|
||||
x_vis = x_spatial
|
||||
|
||||
else:
|
||||
x_spatial = x_spatial + self.drop_path(self.mlp_spatial(self.norm_spatial(x_spatial)))
|
||||
x_vis = x_spatial
|
||||
|
||||
if is_vid:
|
||||
x_temporal = x[:, vis_feat_len:2*vis_feat_len]
|
||||
if expert_permutation is not None:
|
||||
if expert_permutation['temporal'] == 'spatial':
|
||||
x_temporal = x_temporal + self.drop_path(self.mlp_spatial(self.norm_spatial(x_temporal)))
|
||||
elif expert_permutation['temporal'] == 'caption':
|
||||
x_temporal = x_temporal + self.drop_path(self.mlp_cap(self.norm_cap(x_temporal)))
|
||||
elif expert_permutation['temporal'] == 'history':
|
||||
x_temporal = x_temporal + self.drop_path(self.mlp_hist(self.norm_hist(x_temporal)))
|
||||
elif expert_permutation['temporal'] == 'temporal':
|
||||
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
|
||||
else:
|
||||
x_temporal = x_temporal + self.drop_path(self.mlp_temp(self.norm_temp(x_temporal)))
|
||||
x_vis = torch.concat([x_spatial, x_temporal], dim=1)
|
||||
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
|
||||
else:
|
||||
x_vis = x[:, :vis_feat_len]
|
||||
x_vis = x_vis + self.drop_path(self.mlp_vis(self.norm_vis(x_vis)))
|
||||
|
||||
if self.has_hist:
|
||||
x_caption = x[:, -(cap_feat_len + hist_feat_len): -hist_feat_len]
|
||||
if expert_permutation is not None:
|
||||
if expert_permutation['caption'] == 'spatial':
|
||||
x_caption = x_caption + self.drop_path(self.mlp_spatial(self.norm_spatial(x_caption)))
|
||||
elif expert_permutation['caption'] == 'temporal':
|
||||
x_caption = x_caption + self.drop_path(self.mlp_temp(self.norm_temp(x_caption)))
|
||||
elif expert_permutation['caption'] == 'history':
|
||||
x_caption = x_caption + self.drop_path(self.mlp_hist(self.norm_hist(x_caption)))
|
||||
elif expert_permutation['caption'] == 'caption':
|
||||
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||
else:
|
||||
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||
|
||||
|
||||
x_history = x[:, -hist_feat_len:]
|
||||
if expert_permutation is not None:
|
||||
if expert_permutation['history'] == 'spatial':
|
||||
x_history = x_history + self.drop_path(self.mlp_spatial(self.norm_spatial(x_history)))
|
||||
elif expert_permutation['history'] == 'temporal':
|
||||
x_history = x_history + self.drop_path(self.mlp_temp(self.norm_temp(x_history)))
|
||||
elif expert_permutation['history'] == 'caption':
|
||||
x_history = x_history + self.drop_path(self.mlp_cap(self.norm_cap(x_history)))
|
||||
elif expert_permutation['history'] == 'history':
|
||||
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
|
||||
else:
|
||||
x_history = x_history + self.drop_path(self.mlp_hist(self.norm_hist(x_history)))
|
||||
# concat the features back
|
||||
x = torch.cat([x_vis, x_caption, x_history], dim=1)
|
||||
else:
|
||||
x_caption = x[:, -cap_feat_len:]
|
||||
x_caption = x_caption + self.drop_path(self.mlp_cap(self.norm_cap(x_caption)))
|
||||
x = torch.cat([x_vis, x_caption], dim=1)
|
||||
|
||||
assert x.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
|
||||
x.size(1), len_init
|
||||
)
|
||||
|
||||
elif expert_flag == 'fusion':
|
||||
x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
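A hypothetical smoke test of the modality-expert routing above, for a video sample whose hidden sequence is laid out as [spatial | temporal | caption] segments. The import path is an assumption based on this repo's layout, the layer is randomly initialized, and the dimensions are toy values (`dim` must be divisible by `num_heads`).

```python
import torch
from models.backbones.moes import MoELayer   # assumed import path within this repo

dim, num_heads, vis_len, cap_len = 64, 8, 4, 3
layer = MoELayer(dim=dim, num_heads=num_heads, expert_type='modalities')

x = torch.randn(2, 2 * vis_len + cap_len, dim)   # spatial + temporal + caption tokens
out = layer(x, vis_feat_len=vis_len, cap_feat_len=cap_len,
            expert_flag='modalities', is_vid=True)
print(out.shape)   # torch.Size([2, 11, 64]) -> same length, routed through the experts
```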
class Pooler(nn.Module):
|
||||
def __init__(self, hidden_size):
|
||||
super(Pooler, self).__init__()
|
||||
|
||||
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
pooled_states = hidden_states[:, 0]
|
||||
pooled_output = self.dense(pooled_states)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
234
models/backbones/moes_huggingface.py
Normal file
234
models/backbones/moes_huggingface.py
Normal file
|
@ -0,0 +1,234 @@
|
|||
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath
import warnings
from torch import Tensor
from typing import Optional, Tuple


from .bert.xbert import BertLayer, BertAttention, BertIntermediate, BertOutput, BertConfig


class MoELayer(nn.Module):
    def __init__(self, config, expert_type):
        super(MoELayer, self).__init__()
        self.config = config
        self.expert_type = expert_type
        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')

        # Shared across all experts
        self.attention = BertAttention(self.bert_config)

        # One for each expert
        if expert_type == 'modalities':
            # Spatial expert
            self.intermediate_spatial = BertIntermediate(self.bert_config)
            self.output_spatial = BertOutput(self.bert_config)

            # Temporal expert
            self.intermediate_temporal = BertIntermediate(self.bert_config)
            self.output_temporal = BertOutput(self.bert_config)

            # Vis expert
            self.intermediate_vis = BertIntermediate(self.bert_config)
            self.output_vis = BertOutput(self.bert_config)

            # Caption expert
            self.intermediate_caption = BertIntermediate(self.bert_config)
            self.output_caption = BertOutput(self.bert_config)

            if config.stage != 'stage_1':
                # History expert
                self.intermediate_history = BertIntermediate(self.bert_config)
                self.output_history = BertOutput(self.bert_config)

        # Fusion expert
        elif expert_type == 'fusion':
            self.intermediate_fusion = BertIntermediate(self.bert_config)
            self.output_fusion = BertOutput(self.bert_config)
        else:
            raise ValueError

        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()
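    # Summary of the expert layout built above: all experts share the single
    # BertAttention, and each expert is its own BertIntermediate + BertOutput pair.
    # 'modalities' builds spatial, temporal, vis and caption experts (plus a history
    # expert after stage_1); 'fusion' builds a single fusion expert.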
    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
        """
        if dtype is None:
            dtype = self.dtype

        if not (attention_mask.dim() == 2 and self.bert_config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            # if self.config.is_decoder:
            #     extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
            #         input_shape, attention_mask, device
            #     )
            # else:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
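    # Worked example (values assumed): for a padding mask [[1., 1., 0.]] in fp32,
    # the mask becomes shape (1, 1, 1, 3) and (1.0 - mask) * torch.finfo(torch.float32).min
    # yields [0., 0., -3.4028e+38], i.e. padded positions get the dtype minimum added
    # to the attention logits before the softmax.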
    def forward(self, hidden_states, special_toks_indices, expert_flag, mask=None, only_text=False, output_attentions=False):

        input_shape = hidden_states.size()[:-1]
        # dtype = mask.dtype
        # device = mask.device
        extended_attention_mask = self.get_extended_attention_mask(mask, input_shape, dtype=torch.float32)

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            head_mask=None
        )
        attention_output = self_attention_outputs[0]
        # outputs = self_attention_outputs[1:]

        len_init = attention_output.size(1)
        # bs, h_dim = x.size(0), x.size(-1)
        # device = x.device

        if expert_flag == 'modalities':
            if only_text:
                intermediate_output = self.intermediate_caption(attention_output)
                layer_output = self.output_caption(intermediate_output, attention_output)
            else:
                # split the input first into different parts/modalities
                unchanged = attention_output[:, :special_toks_indices['<vis>'], :]
                end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
                attention_spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]

                end_idx_caption = special_toks_indices.get('<history>', special_toks_indices['</s>'] + 1)
                attention_caption = attention_output[:, special_toks_indices['<caption>']:end_idx_caption, :]

                attention_temporal, attention_history = None, None

                if '<temporal>' in special_toks_indices:
                    end_idx_temporal = special_toks_indices['<caption>']
                    attention_temporal = attention_output[:, special_toks_indices['<temporal>']:end_idx_temporal, :]

                if '<history>' in special_toks_indices:
                    end_idx_history = special_toks_indices['</s>'] + 1
                    attention_history = attention_output[:, special_toks_indices['<history>']:end_idx_history, :]

                # Expert activation
                # 1- Spatial
                intermediate_spatial = self.intermediate_spatial(attention_spatial)
                output_spatial = self.output_spatial(intermediate_spatial, attention_spatial)

                output_vis = output_spatial

                # 2- Temporal
                if attention_temporal is not None:
                    intermediate_temporal = self.intermediate_temporal(attention_temporal)
                    output_temporal = self.output_temporal(intermediate_temporal, attention_temporal)

                    attention_vis = torch.concat([output_spatial, output_temporal], dim=1)
                    intermediate_vis = self.intermediate_vis(attention_vis)
                    output_vis = self.output_vis(intermediate_vis, attention_vis)

                # 3- Caption
                intermediate_caption = self.intermediate_caption(attention_caption)
                output_caption = self.output_caption(intermediate_caption, attention_caption)

                # 4- History
                if attention_history is not None:
                    intermediate_history = self.intermediate_history(attention_history)
                    output_history = self.output_history(intermediate_history, attention_history)

                output_list = [unchanged, output_vis, output_caption]

                if attention_history is not None:
                    output_list.append(output_history)

                # Concat the features back
                layer_output = torch.concat(output_list, dim=1)
                assert layer_output.size(1) == len_init, 'Reconstructed features length is {} != original features len = {}'.format(
                    layer_output.size(1), len_init
                )

        elif expert_flag == 'fusion':
            intermediate_output = self.intermediate_fusion(attention_output)
            layer_output = self.output_fusion(intermediate_output, attention_output)

        return layer_output
class MoEPooler(nn.Module):
    def __init__(self):
        super(MoEPooler, self).__init__()

        self.bert_config = BertConfig.from_pretrained('bert-large-uncased')
        hidden_size = self.bert_config.hidden_size

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self._init_weights()

    def _init_weights(self):
        for _, m in dict(self.named_modules()).items():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
            elif isinstance(m, nn.LayerNorm):
                m.bias.data.zero_()
                m.weight.data.fill_(1.0)
            if isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def forward(self, hidden_states, idx):
        pooled_states = hidden_states[:, idx]
        pooled_output = self.dense(pooled_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output
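# A hedged sketch of how the modality split in MoELayer.forward partitions a sequence.
# The keys follow the forward() above; the concrete index values are made up for
# illustration only.
def _demo_modality_split():
    seq_len, hidden = 12, 4
    attention_output = torch.zeros(1, seq_len, hidden)
    special_toks_indices = {'<vis>': 1, '<temporal>': 4, '<caption>': 7, '</s>': 11}
    end_idx_spatial = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
    spatial = attention_output[:, special_toks_indices['<vis>']:end_idx_spatial, :]
    temporal = attention_output[:, special_toks_indices['<temporal>']:special_toks_indices['<caption>'], :]
    caption = attention_output[:, special_toks_indices['<caption>']:special_toks_indices['</s>'] + 1, :]
    return spatial.shape, temporal.shape, caption.shape   # (1, 3, 4), (1, 3, 4), (1, 5, 4)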
247
models/backbones/moes_original.py
Normal file
@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from timm.models.layers import DropPath


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
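# The Mlp above is the per-expert feed-forward block used by MoELayer below;
# e.g. Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio)) maps
# (B, N, dim) -> (B, N, dim), with mlp_ratio=4.0 as the MoELayer default.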
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            if mask.dim() != x.dim():
                expanded_mask = mask[:, None, None, :].expand(B, 1, N, N)
            else:
                expanded_mask = mask
            expanded_mask = expanded_mask.bool()
            attn = attn.masked_fill(~expanded_mask, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
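# Hedged usage sketch for the Attention block above (dims and mask are assumptions):
# a (B, N) padding mask is expanded to (B, 1, N, N) and masked positions are set to
# -inf before the softmax, so they receive zero attention weight.
def _demo_attention():
    attn_layer = Attention(dim=768, num_heads=8, qkv_bias=True)
    x = torch.randn(2, 16, 768)
    mask = torch.ones(2, 16)
    mask[:, -4:] = 0                         # pretend the last 4 tokens are padding
    out, attn = attn_layer(x, mask=mask)     # out: (2, 16, 768), attn: (2, 8, 16, 16)
    return out.shape, attn.shape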
class MoELayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.SiLU,
        norm_layer=LlamaRMSNorm,
    ):
        super().__init__()
        self.norm_att = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # EXPERT CONSTRUCTION
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Spatial expert
        self.norm_spatial = norm_layer(dim)
        self.mlp_spatial = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Temporal expert
        self.norm_temp = norm_layer(dim)
        self.mlp_temp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Vis expert
        self.norm_vis = norm_layer(dim)
        self.mlp_vis = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Caption expert
        self.norm_cap = norm_layer(dim)
        self.mlp_cap = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # History expert
        self.norm_hist = norm_layer(dim)
        self.mlp_hist = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Fusion expert
        self.norm_fusion = norm_layer(dim)
        self.mlp_fusion = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    # expert_flag: {Only Text : 00, Only Image : 01, Fusion : 10, Text & Image : 11} (BINARY)

    # expert_flag:
    # 0:
    def forward(self, x, special_toks_indices, expert_flag, mask=None):
        x_shortcut, attn = self.attn(self.norm_att(x), mask=mask)
        x = x + self.drop_path(x_shortcut)
        bs, h_dim = x.size(0), x.size(-1)
        device = x.device

        if expert_flag == 'modalities':
            end_index = special_toks_indices.get('<temporal>', special_toks_indices['<caption>'])
            spatial_feats = x[:, special_toks_indices['<spatial>']:end_index, :]
            spatial_feats = spatial_feats + self.drop_path(self.mlp_spatial(self.norm_spatial(spatial_feats)))
            spatial_index = torch.arange(special_toks_indices['<spatial>'], end_index, device=device)
            spatial_index = spatial_index.unsqueeze(0).unsqueeze(-1)
            spatial_index = spatial_index.repeat(bs, 1, h_dim)
            x = x.scatter(1, spatial_index, spatial_feats)
            # x[:, special_toks_indices['<spatial>']: special_toks_indices['<temporal>'], :] = spatial_feats

            end_index = special_toks_indices.get('<history>', special_toks_indices['</s>'])
            caption_feats = x[:, special_toks_indices['<caption>']:end_index, :]
            caption_feats = caption_feats + self.drop_path(self.mlp_cap(self.norm_cap(caption_feats)))
            caption_index = torch.arange(special_toks_indices['<caption>'], end_index, device=device)
            caption_index = caption_index.unsqueeze(0).unsqueeze(-1)
            caption_index = caption_index.repeat(bs, 1, h_dim)
            x = x.scatter(1, caption_index, caption_feats)

            # x[:, special_toks_indices['<caption>']: special_toks_indices['</s>'], :] = caption_feats

            if '<temporal>' in special_toks_indices:
                temporal_feats = x[:, special_toks_indices['<temporal>']:special_toks_indices['<caption>'], :]
                temporal_feats = temporal_feats + self.drop_path(self.mlp_temp(self.norm_temp(temporal_feats)))
                temporal_index = torch.arange(special_toks_indices['<temporal>'], special_toks_indices['<caption>'], device=device)
                temporal_index = temporal_index.unsqueeze(0).unsqueeze(-1)
                temporal_index = temporal_index.repeat(bs, 1, h_dim)
                x = x.scatter(1, temporal_index, temporal_feats)

                # x[:, special_toks_indices['<temporal>']: special_toks_indices['<caption>'], :] = temporal_feats

            vis_feats = x[:, special_toks_indices['<vis>']:special_toks_indices['<caption>'], :]
            vis_feats = vis_feats + self.drop_path(self.mlp_vis(self.norm_vis(vis_feats)))
            vis_index = torch.arange(special_toks_indices['<vis>'], special_toks_indices['<caption>'], device=device)
            vis_index = vis_index.unsqueeze(0).unsqueeze(-1)
            vis_index = vis_index.repeat(bs, 1, h_dim)
            x = x.scatter(1, vis_index, vis_feats)

            # x[:, special_toks_indices['<vis>']: special_toks_indices['<caption>'], :] = vis_feats

            if '<history>' in special_toks_indices:
                history_feats = x[:, special_toks_indices['<history>']:special_toks_indices['</s>'], :]
                history_feats = history_feats + self.drop_path(self.mlp_hist(self.norm_hist(history_feats)))
                history_index = torch.arange(special_toks_indices['<history>'], special_toks_indices['</s>'], device=device)
                history_index = history_index.unsqueeze(0).unsqueeze(-1)
                history_index = history_index.repeat(bs, 1, h_dim)
                x = x.scatter(1, history_index, history_feats)

        elif expert_flag == 'fusion':
            x = x + self.drop_path(self.mlp_fusion(self.norm_fusion(x)))

        return x, attn

        # if expert_flag == 2:
        #     x = x + self.drop_path(self.mlp(self.norm2(x)))
        # elif expert_flag == 0:
        #     x = (x[:, -it_split:])
        #     x = x + self.drop_path(self.sentence_mlp(self.sentence_norm(x)))
        # elif expert_flag == 1:
        #     x = (x[:, :-it_split])
        #     x = x + self.drop_path(self.image_mlp(self.image_norm(x)))
        # elif expert_flag == 3:
        #     text, image = (x[:, :it_split], x[:, it_split:],)
        #     text = text + self.drop_path(self.sentence_mlp(self.sentence_norm(text)))
        #     image = image + self.drop_path(self.image_mlp(self.image_norm(image)))
        #     x = torch.cat([text, image], dim=1)
        # elif expert_flag == 4:
        #     x = x + self.drop_path(self.generation_mlp(self.generation_norm(x)))
        # return x, attn
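# Hedged sketch of the scatter-based write-back used in MoELayer.forward above:
# each expert's output is written back into its original token positions along dim=1.
# The span boundaries here are illustrative, not taken from the tokenizer.
def _demo_scatter_writeback():
    bs, seq_len, h_dim = 2, 6, 4
    x = torch.zeros(bs, seq_len, h_dim)
    start, end = 1, 4                                   # token span routed to one expert
    feats = torch.ones(bs, end - start, h_dim)          # stand-in for the expert output
    index = torch.arange(start, end).unsqueeze(0).unsqueeze(-1).repeat(bs, 1, h_dim)
    return x.scatter(1, index, feats)                   # rows 1..3 now hold the expert output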