initial commit
commit 7be61f8c6d
137 changed files with 33491 additions and 0 deletions
models/utils.py · 266 lines · Normal file
@@ -0,0 +1,266 @@
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import interpolate
from typing import List

logger = logging.getLogger(__name__)


class MLM:
    def __init__(
        self,
        mask_token: int,
        padding_token: int,
        no_mask_tokens: List[int],
        n_tokens: int,
        masking_prob: float = 0.15,
        randomize_prob: float = 0.1,
        no_change_prob: float = 0.1,
    ):
        self.mask_token = mask_token
        self.padding_token = padding_token
        self.no_mask_tokens = list(set(no_mask_tokens + [padding_token, mask_token]))
        self.n_tokens = n_tokens
        self.masking_prob = masking_prob
        self.randomize_prob = randomize_prob
        self.no_change_prob = no_change_prob

    def __call__(self, x: torch.Tensor):
        full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
        for tok in self.no_mask_tokens:
            full_mask &= x != tok  # never mask special tokens (mask, padding, etc.)

        unchanged_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
        random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
        random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
        random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
        mask = full_mask & ~random_token_mask & ~unchanged_mask

        y = x.clone().detach()
        x.masked_fill_(mask, self.mask_token)
        x[random_token_idx] = random_tokens
        y.masked_fill_(~full_mask, self.padding_token)

        return x, y
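

# Illustrative usage sketch for MLM: the token ids, vocabulary size, and batch
# shape below are assumptions chosen for demonstration only.
def _example_mlm_usage():
    mlm = MLM(mask_token=103, padding_token=0, no_mask_tokens=[101, 102], n_tokens=30522)
    tokens = torch.randint(0, 30522, (2, 16))  # (batch_size, seq_len)
    masked_tokens, labels = mlm(tokens.clone())  # __call__ modifies its input in place
    # `labels` keeps the original ids at corrupted positions and `padding_token`
    # everywhere else, so a cross-entropy loss can ignore the padding id.
    return masked_tokens, labels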


def _init_transformer_weights(module, initializer_range=0.02):
    """Initialize the weights. Copied from transformers ViT/Bert model init"""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
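

# Illustrative sketch of applying the initializer to a module tree with
# nn.Module.apply; the tiny network is an assumption for demonstration only.
def _example_init_weights():
    net = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    net.apply(lambda m: _init_transformer_weights(m, initializer_range=0.02))
    return net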


def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """
    Add/Remove extra temporal_embeddings as needed.
    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    temp_embed_old: (1, num_frames_old, 1, d)
    temp_embed_new: (1, num_frames_new, 1, d)
    add_zero: bool, if True, pad the new frames with zeros; otherwise interpolate
        the trained embeddings to the new length.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
    if num_frms_new > num_frms_old:
        if add_zero:
            temp_embed_new[:, :num_frms_old] = temp_embed_old  # untrained embeddings are zeros.
        else:
            temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
    elif num_frms_new < num_frms_old:
        temp_embed_new = temp_embed_old[:, :num_frms_new]
    else:  # equal lengths, reuse the old embeddings directly
        temp_embed_new = temp_embed_old
    return temp_embed_new
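

# Illustrative sketch: grow a 4-frame temporal embedding to 8 frames by zero
# padding; the embedding dimension (d=768) is an assumption for demonstration only.
def _example_load_temp_embed():
    temp_embed_old = torch.randn(1, 4, 1, 768)
    temp_embed_new = torch.zeros(1, 8, 1, 768)  # freshly initialized, all zeros
    return load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True)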


def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """
    temp_embed_old: (1, num_frames_old, 1, d)
    Returns:
        temp_embed_new: (1, num_frames_new, 1, d)
    """
    temp_embed_old = temp_embed_old.squeeze(2).permute(0, 2, 1)  # (1, d, num_frames_old)
    temp_embed_new = F.interpolate(temp_embed_old, num_frames_new, mode="linear")  # (1, d, num_frames_new)
    temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(2)  # (1, num_frames_new, 1, d)
    return temp_embed_new
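

# Illustrative sketch: linearly resample a temporal embedding from 4 to 12
# frames; d=768 is an assumption for demonstration only.
def _example_interpolate_temporal():
    temp_embed_old = torch.randn(1, 4, 1, 768)
    return interpolate_temporal_pos_embed(temp_embed_old, num_frames_new=12)  # (1, 12, 1, 768)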


def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
    """
    Args:
        pos_embed_old: (1, L_old, d), pre-trained
        pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
        num_patches_new: int, number of patches in the new (square) image grid
    """
    # interpolate position embedding
    embedding_size = pos_embed_old.shape[-1]
    num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches_new ** 0.5)

    if orig_size != new_size:
        # class_token and dist_token are kept unchanged;
        # the extra tokens seem to always sit at the beginning of the position embedding
        extra_tokens = pos_embed_old[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_old[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size
        ).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
        return interpolated_pos_embed
    else:
        return pos_embed_old
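

# Illustrative sketch: resize a ViT position embedding from a 14x14 to a 16x16
# patch grid with a single [CLS] token; sizes are assumptions for demonstration only.
def _example_interpolate_pos_embed():
    pos_embed_old = torch.randn(1, 1 + 14 * 14, 768)  # pre-trained checkpoint layout
    pos_embed_new = torch.zeros(1, 1 + 16 * 16, 768)  # target model layout
    return interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new=16 * 16)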


def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
    """
    Args:
        state_dict_old: loaded state dict
        state_dict_new: state dict for model with new image size
        patch_shape_new: new model patch_shape
    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
    """
    all_keys = list(state_dict_old.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            state_dict_old.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict_old[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = state_dict_new[key].size()
            dst_patch_shape = patch_shape_new
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
                dst_patch_shape[1] * 2 - 1
            )
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                # logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
                #     key, src_size, src_size, dst_size, dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                # logger.info("Original positions = %s" % str(x))
                # logger.info("Target positions = %s" % str(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind="cubic")
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy))
                        .contiguous()
                        .view(-1, 1)
                        .to(rel_pos_bias.device)
                    )

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict_old[key] = new_rel_pos_bias
    return state_dict_old
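

# Illustrative sketch: resize a BEiT relative position bias table from a 14x14 to
# a 16x16 patch grid. The key name, 12 heads, and the 3 extra bias entries are
# assumptions for demonstration only; this path relies on scipy's interp2d as used above.
def _example_interpolate_rel_bias():
    key = "blocks.0.attn.relative_position_bias_table"
    state_dict_old = {key: torch.randn((2 * 14 - 1) ** 2 + 3, 12)}
    state_dict_new = {key: torch.zeros((2 * 16 - 1) ** 2 + 3, 12)}
    return interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new=(16, 16))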


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))
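

# Illustrative sketch: `tile` repeats entries along one dimension in interleaved
# order ([a, b] -> [a, a, b, b]); the tensor values are demonstration values only.
def _example_tile():
    x = torch.tensor([[1, 2], [3, 4]])
    return tile(x, dim=0, n_tile=2)  # rows become [row0, row0, row1, row1]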


def mask_logits(target, mask):
    return target * mask + (1 - mask) * (-1e10)
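

# Illustrative sketch: suppress padded positions before a softmax. The mask
# convention (1 = keep, 0 = mask out) follows mask_logits above; shapes are
# demonstration values only.
def _example_mask_logits():
    logits = torch.randn(2, 5)
    mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=logits.dtype)
    return F.softmax(mask_logits(logits, mask), dim=-1)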


class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None,
        )


allgather_wgrad = AllGather.apply
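

# Illustrative sketch: gather features from every rank while keeping gradients for
# the local shard (e.g. before a loss over the global batch). Assumes
# torch.distributed is initialized and `args` exposes `world_size` and `rank`,
# mirroring what forward() expects.
def _example_allgather_wgrad(local_feats, args):
    gathered = allgather_wgrad(local_feats, args)  # (world_size * local_batch, d)
    return gathered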