initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
247
models/backbones/base_model.py
Executable file
247
models/backbones/base_model.py
Executable file
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from models.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
|
||||
from models.common.utils import get_abs_path, is_url
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
"""Base class for models."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return list(self.parameters())[0].device
|
||||
|
||||
def load_checkpoint(self, url_or_filename):
|
||||
"""
|
||||
Load from a finetuned checkpoint.
|
||||
|
||||
This should expect no mismatch in the model keys and the checkpoint keys.
|
||||
"""
|
||||
|
||||
if is_url(url_or_filename):
|
||||
cached_file = download_cached_file(
|
||||
url_or_filename, check_hash=False, progress=True
|
||||
)
|
||||
checkpoint = torch.load(cached_file, map_location="cpu")
|
||||
elif os.path.isfile(url_or_filename):
|
||||
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
||||
else:
|
||||
raise RuntimeError("checkpoint url or path is invalid")
|
||||
|
||||
if "model" in checkpoint.keys():
|
||||
state_dict = checkpoint["model"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
msg = self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
logging.info("Missing keys {}".format(msg.missing_keys))
|
||||
logging.info("load checkpoint from %s" % url_or_filename)
|
||||
|
||||
return msg
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_type):
|
||||
"""
|
||||
Build a pretrained model from default configuration file, specified by model_type.
|
||||
|
||||
Args:
|
||||
- model_type (str): model type, specifying architecture and checkpoints.
|
||||
|
||||
Returns:
|
||||
- model (nn.Module): pretrained or finetuned model, depending on the configuration.
|
||||
"""
|
||||
model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
|
||||
model = cls.from_config(model_cfg)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def default_config_path(cls, model_type):
|
||||
assert (
|
||||
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
|
||||
), "Unknown model type {}".format(model_type)
|
||||
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
|
||||
|
||||
def load_checkpoint_from_config(self, cfg, **kwargs):
|
||||
"""
|
||||
Load checkpoint as specified in the config file.
|
||||
|
||||
If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
|
||||
When loading the pretrained model, each task-specific architecture may define their
|
||||
own load_from_pretrained() method.
|
||||
"""
|
||||
load_finetuned = cfg.get("load_finetuned", True)
|
||||
if load_finetuned:
|
||||
finetune_path = cfg.get("finetuned", None)
|
||||
assert (
|
||||
finetune_path is not None
|
||||
), "Found load_finetuned is True, but finetune_path is None."
|
||||
self.load_checkpoint(url_or_filename=finetune_path)
|
||||
else:
|
||||
# load pre-trained weights
|
||||
pretrain_path = cfg.get("pretrained", None)
|
||||
assert "Found load_finetuned is False, but pretrain_path is None."
|
||||
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
|
||||
|
||||
def before_evaluation(self, **kwargs):
|
||||
pass
|
||||
|
||||
def show_n_params(self, return_str=True):
|
||||
tot = 0
|
||||
for p in self.parameters():
|
||||
w = 1
|
||||
for x in p.shape:
|
||||
w *= x
|
||||
tot += w
|
||||
if return_str:
|
||||
if tot >= 1e6:
|
||||
return "{:.1f}M".format(tot / 1e6)
|
||||
else:
|
||||
return "{:.1f}K".format(tot / 1e3)
|
||||
else:
|
||||
return tot
|
||||
|
||||
|
||||
class BaseEncoder(nn.Module):
|
||||
"""
|
||||
Base class for primitive encoders, such as ViT, TimeSformer, etc.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward_features(self, samples, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return list(self.parameters())[0].device
|
||||
|
||||
|
||||
class SharedQueueMixin:
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
|
||||
# gather keys before updating queue
|
||||
image_feats = concat_all_gather(image_feat)
|
||||
text_feats = concat_all_gather(text_feat)
|
||||
|
||||
batch_size = image_feats.shape[0]
|
||||
|
||||
ptr = int(self.queue_ptr)
|
||||
assert self.queue_size % batch_size == 0 # for simplicity
|
||||
|
||||
# replace the keys at ptr (dequeue and enqueue)
|
||||
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
|
||||
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
|
||||
|
||||
if idxs is not None:
|
||||
idxs = concat_all_gather(idxs)
|
||||
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
|
||||
|
||||
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
||||
self.queue_ptr[0] = ptr
|
||||
|
||||
|
||||
class MomentumDistilationMixin:
|
||||
@torch.no_grad()
|
||||
def copy_params(self):
|
||||
for model_pair in self.model_pairs:
|
||||
for param, param_m in zip(
|
||||
model_pair[0].parameters(), model_pair[1].parameters()
|
||||
):
|
||||
param_m.data.copy_(param.data) # initialize
|
||||
param_m.requires_grad = False # not update by gradient
|
||||
|
||||
@torch.no_grad()
|
||||
def _momentum_update(self):
|
||||
for model_pair in self.model_pairs:
|
||||
for param, param_m in zip(
|
||||
model_pair[0].parameters(), model_pair[1].parameters()
|
||||
):
|
||||
param_m.data = param_m.data * self.momentum + param.data * (
|
||||
1.0 - self.momentum
|
||||
)
|
||||
|
||||
|
||||
class GatherLayer(torch.autograd.Function):
|
||||
"""
|
||||
Gather tensors from all workers with support for backward propagation:
|
||||
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
output = [
|
||||
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(output, x)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
all_gradients = torch.stack(grads)
|
||||
torch.distributed.all_reduce(all_gradients)
|
||||
return all_gradients[torch.distributed.get_rank()]
|
||||
|
||||
|
||||
def all_gather_with_grad(tensors):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
Graph remains connected for backward grad computation.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = torch.distributed.get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
|
||||
# tensor_all = GatherLayer.apply(tensors)
|
||||
tensor_all = GatherLayer.apply(tensors)
|
||||
|
||||
return torch.cat(tensor_all, dim=0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def concat_all_gather(tensor):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
*** Warning ***: torch.distributed.all_gather has no gradient.
|
||||
"""
|
||||
# if use distributed training
|
||||
if not is_dist_avail_and_initialized():
|
||||
return tensor
|
||||
|
||||
tensors_gather = [
|
||||
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
||||
|
||||
output = torch.cat(tensors_gather, dim=0)
|
||||
return output
|
||||
|
||||
|
||||
def tile(x, dim, n_tile):
|
||||
init_dim = x.size(dim)
|
||||
repeat_idx = [1] * x.dim()
|
||||
repeat_idx[dim] = n_tile
|
||||
x = x.repeat(*(repeat_idx))
|
||||
order_index = torch.LongTensor(
|
||||
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
|
||||
)
|
||||
return torch.index_select(x, dim, order_index.to(x.device))
|
Loading…
Add table
Add a link
Reference in a new issue