266 lines
9.8 KiB
Python
266 lines
9.8 KiB
Python
import logging
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from scipy import interpolate
|
|
from typing import List
|
|
|
|
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MLM:
    """BERT-style masked-language-model corruption of a batch of token ids.

    A fraction ``masking_prob`` of the maskable positions is selected for
    prediction. Each selected position is then:

    * left unchanged with probability ``no_change_prob``,
    * replaced by a uniformly random id in ``[0, n_tokens)`` with probability
      ``randomize_prob``,
    * replaced by ``mask_token`` otherwise.

    The three outcomes are mutually exclusive (decided from a single uniform
    draw per position), so with the defaults the split is exactly 80/10/10.

    Args:
        mask_token: id written into positions chosen for masking.
        padding_token: id used to fill label positions that are not predicted.
        no_mask_tokens: ids that must never be selected (e.g. special tokens).
            ``padding_token`` and ``mask_token`` are always added. Any iterable
            of ints is accepted.
        n_tokens: random replacement ids are drawn from ``[0, n_tokens)``.
        masking_prob: probability that a maskable position is selected.
        randomize_prob: probability that a selected position gets a random id.
        no_change_prob: probability that a selected position is left unchanged.

    Raises:
        ValueError: if ``randomize_prob + no_change_prob > 1``.
    """

    def __init__(
        self,
        mask_token: int,
        padding_token: int,
        no_mask_tokens: List[int],
        n_tokens: int,
        masking_prob: float = 0.15,
        randomize_prob: float = 0.1,
        no_change_prob: float = 0.1,
    ):
        if randomize_prob + no_change_prob > 1.0:
            raise ValueError(
                "randomize_prob + no_change_prob must not exceed 1, got "
                f"{randomize_prob} + {no_change_prob}"
            )
        self.mask_token = mask_token
        self.padding_token = padding_token
        # list(...) so tuples/sets are accepted as well as lists.
        self.no_mask_tokens = list(set(list(no_mask_tokens) + [padding_token, mask_token]))
        self.n_tokens = n_tokens
        self.masking_prob = masking_prob
        self.randomize_prob = randomize_prob
        self.no_change_prob = no_change_prob

    def __call__(self, x: torch.Tensor):
        """Corrupt ``x`` and build the prediction labels.

        Args:
            x: integer tensor of token ids (any shape). NOTE: ``x`` is modified
                in place; pass ``x.clone()`` if the original ids are still needed.

        Returns:
            ``(x, y)`` where ``x`` is the corrupted input (the same tensor object
            that was passed in) and ``y`` holds the original ids at selected
            positions and ``padding_token`` everywhere else.
        """
        full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
        for tok in self.no_mask_tokens:
            full_mask &= x != tok  # never select special tokens

        # A single uniform draw per position keeps the outcomes disjoint:
        #   [0, p_keep)                  -> keep the original token
        #   [p_keep, p_keep + p_random)  -> random token
        #   [p_keep + p_random, 1)       -> mask_token
        choice = torch.rand(x.shape, device=x.device)
        keep_upper = self.no_change_prob
        random_upper = self.no_change_prob + self.randomize_prob
        unchanged_mask = full_mask & (choice < keep_upper)
        random_token_mask = full_mask & (choice >= keep_upper) & (choice < random_upper)
        mask = full_mask & ~random_token_mask & ~unchanged_mask

        random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
        # Match x's dtype: indexed assignment requires identical dtypes.
        random_tokens = torch.randint(
            0,
            self.n_tokens,
            (len(random_token_idx[0]),),
            device=x.device,
            dtype=x.dtype,
        )

        y = x.clone().detach()
        x.masked_fill_(mask, self.mask_token)
        x[random_token_idx] = random_tokens
        # Only selected positions contribute to the loss; the rest get padding.
        y.masked_fill_(~full_mask, self.padding_token)
        return x, y
|
|
|
|
|
|
|
|
def _init_transformer_weights(module, initializer_range=0.02):
|
|
"""Initialize the weights. Copied from transformers ViT/Bert model init"""
|
|
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
|
module.weight.data.normal_(mean=0.0, std=initializer_range)
|
|
if module.bias is not None:
|
|
module.bias.data.zero_()
|
|
elif isinstance(module, nn.Embedding):
|
|
module.weight.data.normal_(mean=0.0, std=initializer_range)
|
|
if module.padding_idx is not None:
|
|
module.weight.data[module.padding_idx].zero_()
|
|
elif isinstance(module, nn.LayerNorm):
|
|
module.bias.data.zero_()
|
|
module.weight.data.fill_(1.0)
|
|
|
|
|
|
def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """Adapt a temporal embedding to a different number of frames.

    Extra temporal embeddings are added or removed as needed.
    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    Args:
        temp_embed_old: pre-trained embedding, (1, num_frames_old, 1, d).
        temp_embed_new: embedding of the target model, (1, num_frames_new, 1, d).
            Only its shape (and, for zero padding, dtype/device) is used; it is
            not modified.
        add_zero: when growing, if True keep the old embeddings for the first
            ``num_frames_old`` frames and zero-fill the remaining frames; if
            False, linearly interpolate the old embeddings to the new length.

    Returns:
        Tensor of shape (1, num_frames_new, 1, d). When shrinking this is a
        slice of ``temp_embed_old``; when the lengths match it is
        ``temp_embed_old`` itself.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
    if num_frms_new > num_frms_old:
        if add_zero:
            # Build a fresh zero tensor instead of writing into temp_embed_new:
            # the extra frames are then guaranteed to be zero, the caller's
            # tensor is not mutated, and leaf Parameters do not trip autograd.
            padded = torch.zeros_like(temp_embed_new)
            padded[:, :num_frms_old] = temp_embed_old
            return padded
        return interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
    if num_frms_new < num_frms_old:
        # Drop the trailing frames.
        return temp_embed_old[:, :num_frms_new]
    return temp_embed_old
|
|
|
|
|
|
def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """Linearly resample a temporal embedding along its frame axis.

    Args:
        temp_embed_old: tensor of shape (1, num_frames_old, 1, d).
        num_frames_new: number of frames in the result.

    Returns:
        temp_embed_new: tensor of shape (1, num_frames_new, 1, d).
    """
    # F.interpolate resamples the last axis, so move frames there: (1, d, T_old).
    channels_first = temp_embed_old.squeeze(2).permute(0, 2, 1)
    resampled = F.interpolate(channels_first, num_frames_new, mode="linear")
    # Back to (1, T_new, d), then restore the singleton axis -> (1, T_new, 1, d).
    return resampled.permute(0, 2, 1).unsqueeze(2)
|
|
|
|
|
|
def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
    """Resize a ViT-style position embedding to a new square patch grid.

    Args:
        pos_embed_old: (1, L_old, d), pre-trained.
        pos_embed_new: (1, L_new, d), newly initialized; only its length is
            read, to infer how many leading extra tokens (class/dist) exist.
        num_patches_new: number of patch tokens in the new model; its square
            root (truncated) is taken as the new grid side.

    Returns:
        ``pos_embed_old`` itself when the grid side is unchanged; otherwise a
        new tensor of the extra tokens (copied) followed by the bicubically
        resized patch tokens.
    """
    dim = pos_embed_old.shape[-1]
    n_extra = pos_embed_new.shape[-2] - num_patches_new
    # Grid side (height == width) of the checkpoint and of the new model.
    old_side = int((pos_embed_old.shape[-2] - n_extra) ** 0.5)
    new_side = int(num_patches_new ** 0.5)

    if old_side == new_side:
        return pos_embed_old

    # Extra tokens are assumed to sit at the front and are kept unchanged;
    # only the grid part is interpolated.
    prefix = pos_embed_old[:, :n_extra]
    grid = pos_embed_old[:, n_extra:]
    grid = grid.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=(new_side, new_side), mode="bicubic", align_corners=False
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    resized = torch.cat((prefix, grid), dim=1)
    logger.info(f"reshape position embedding from {old_side}**2 to {new_side}**2")
    return resized
|
|
|
|
|
|
def _beit_source_coords(src_size, dst_size):
    """Return the 1-D positions assigned to the ``src_size`` source offsets.

    The offsets are spaced with a geometric progression whose ratio ``q`` is
    found by bisection on [1.01, 1.5], so that the ``src_size // 2`` outer
    offsets roughly span ``dst_size // 2`` target offsets.
    """

    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        if geometric_progression(1, q, src_size // 2) > dst_size // 2:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]
    return r_ids + [0] + dis


def _resize_rel_pos_bias_table(rel_pos_bias, src_size, dst_size):
    """Bicubically resample a (src_size**2, heads) table to (dst_size**2, heads)."""
    num_attn_heads = rel_pos_bias.shape[1]
    x = y = _beit_source_coords(src_size, dst_size)
    t = dst_size // 2.0
    dx = dy = np.arange(-t, t + 0.1, 1.0)

    device = rel_pos_bias.device
    # numpy/scipy need host memory and a tensor without autograd history.
    table = rel_pos_bias.detach().float().cpu()
    columns = []
    for head in range(num_attn_heads):
        z = table[:, head].view(src_size, src_size).numpy()
        # Replacement for the removed ``interp2d(x, y, z, kind="cubic")(dx, dy)``:
        # RectBivariateSpline indexes z as (x, y) and returns (len(dx), len(dy)),
        # the transpose of interp2d's (y, x) convention on both ends.
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3, s=0)
        resampled = np.ascontiguousarray(spline(dx, dy).T)
        columns.append(torch.Tensor(resampled).contiguous().view(-1, 1).to(device))
    return torch.cat(columns, dim=-1)


def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
    """Resize BEiT relative-position bias tables in a checkpoint to a new window.

    Args:
        state_dict_old: loaded state dict; modified in place and returned.
        state_dict_new: state dict for the model with the new image size; only
            the sizes of its ``relative_position_bias_table`` entries are read.
        patch_shape_new: new model patch_shape (h, w); must be square.

    Returns:
        ``state_dict_old`` with every ``relative_position_index`` entry removed
        and every ``relative_position_bias_table`` entry resampled to the new
        size. Grid entries are resampled; the trailing extra-token rows are
        copied unchanged.

    Raises:
        NotImplementedError: if ``patch_shape_new`` is not square.

    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
    """
    for key in list(state_dict_old.keys()):
        if "relative_position_index" in key:
            state_dict_old.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict_old[key]
            src_num_pos, _ = rel_pos_bias.size()
            dst_num_pos, _ = state_dict_new[key].size()
            dst_patch_shape = patch_shape_new
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
                dst_patch_shape[1] * 2 - 1
            )
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                # Extra-token rows sit at the end of the table. Split at an
                # explicit index: ``[-0:]`` / ``[:-0]`` would misbehave when
                # there are no extra tokens.
                split = src_num_pos - num_extra_tokens
                extra_tokens = rel_pos_bias[split:, :]
                grid_bias = rel_pos_bias[:split, :]
                resized = _resize_rel_pos_bias_table(grid_bias, src_size, dst_size)
                state_dict_old[key] = torch.cat((resized, extra_tokens), dim=0)
    return state_dict_old
|
|
|
|
|
|
def tile(x, dim, n_tile):
    """Repeat every slice of ``x`` along ``dim`` ``n_tile`` times, consecutively.

    Along dim 0, ``[a, b]`` becomes ``[a, a, b, b]`` for ``n_tile=2`` (unlike
    ``Tensor.repeat``, which gives ``[a, b, a, b]``).

    Args:
        x: input tensor.
        dim: dimension to expand (negative values allowed).
        n_tile: number of consecutive copies of each slice.

    Returns:
        Tensor whose size along ``dim`` is ``x.size(dim) * n_tile``, on the
        same device as ``x``.
    """
    # repeat_interleave produces exactly this ordering without building a
    # host-side numpy index, and also handles size-0 dimensions.
    return torch.repeat_interleave(x, n_tile, dim=dim)
|
|
|
|
|
|
def mask_logits(target, mask):
    """Replace masked-out entries of ``target`` with a large negative value.

    Computes ``target * mask + (1 - mask) * fill``: entries where ``mask == 1``
    keep their value and entries where ``mask == 0`` become ``fill``.

    Args:
        target: logits tensor.
        mask: tensor broadcastable to ``target``; 1 keeps, 0 masks. Boolean
            masks are accepted and converted to ``target``'s dtype.

    Returns:
        The masked logits. ``fill`` is ``-1e10``, clamped to the most negative
        finite value of the floating result dtype (this matters for float16,
        where ``-1e10`` would overflow to ``-inf``).
    """
    if mask.dtype == torch.bool:
        # ``1 - mask`` is not supported for bool tensors.
        mask = mask.to(target.dtype)
    fill = -1e10
    result_dtype = torch.result_type(target, mask)
    if result_dtype.is_floating_point:
        fill = max(fill, torch.finfo(result_dtype).min)
    return target * mask + (1 - mask) * fill
|
|
|
|
|
|
class AllGather(torch.autograd.Function):
    """Autograd function that all-gathers ``tensor`` from every rank along dim 0.

    ``args`` must provide ``world_size`` and ``rank``. ``backward`` issues no
    collective: each rank keeps only the rows of ``grad_output`` that line up
    with its own input, i.e. rows ``[rank * B, (rank + 1) * B)`` where ``B`` is
    this rank's ``tensor.shape[0]``.
    """

    @staticmethod
    def forward(ctx, tensor, args):
        # One pre-allocated receive buffer per rank; presumably every rank
        # contributes a tensor of identical shape.
        gathered = []
        for _ in range(args.world_size):
            gathered.append(torch.empty_like(tensor))
        torch.distributed.all_gather(gathered, tensor)
        # Remember where this rank's rows sit in the concatenated output.
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        start = ctx.batch_size * ctx.rank
        stop = start + ctx.batch_size
        local_grad = grad_output[start:stop]
        # No gradient for the non-tensor ``args`` input.
        return local_grad, None


# Functional entry point: ``allgather_wgrad(tensor, args)``.
allgather_wgrad = AllGather.apply
|