import logging
import math

import einops
import torch
from einops import rearrange
from timm.models.layers.drop import DropPath
from torch import nn
from torch.nn import LayerNorm, Linear, MultiheadAttention

logger = logging.getLogger(__name__)


class STAdapter(nn.Module):
    """ST-Adapter: temporal modeling via a depthwise 3D convolution inside a
    bottleneck (down-project, conv over (t, h, w), up-project), added back
    through a zero-initialized residual gate."""

    def __init__(
        self,
        kernel_size=(3, 3, 3),
        input_dim=768,
        hidden_dim=384,
        img_size=224,
        patch_size=16,
        drop_prob=0.1,
    ):
        super(STAdapter, self).__init__()
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.h = self.w = img_size // patch_size

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.act = nn.ReLU()
        # depthwise 3D convolution over (frames, height, width).
        self.conv = nn.Conv3d(
            hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim
        )
        self.droppath = DropPath(drop_prob=drop_prob)

        # zero-initialized gate: the adapter starts out as an identity mapping.
        self.scale = nn.parameter.Parameter(torch.zeros([]))

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x

        shortcut = x
        x = self.linear1(x)
        cls = x[:, :, :1, :]
        tokens = x[:, :, 1:, :]
        tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous()
        tokens = self.conv(tokens)
        tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c")
        x = torch.cat([cls, tokens], dim=2)  # [b, t, 1+h*w, c]
        x = self.act(x)
        x = self.linear2(x)

        return shortcut + self.scale * self.droppath(x)
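

# Usage sketch (added for illustration; not part of the original module).
# With the defaults (img_size=224, patch_size=16, input_dim=768) the token
# length is l = 1 + 14 * 14 = 197 and the adapter preserves the input shape.
# >>> adapter = STAdapter()
# >>> feats = torch.randn(2, 8, 197, 768)   # [bs, nframes, l, c]
# >>> adapter(feats).shape
# torch.Size([2, 8, 197, 768])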


class SpatialAttention(nn.Module):
    """Perform spatial self-attention among the tokens of each frame."""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        super(SpatialAttention, self).__init__()
        self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.linear = Linear(input_dim, input_dim)
        self.droppath = DropPath(droppath_rate)
        # self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.scale = 1.0

    def forward(self, x: torch.Tensor):
        if x.shape[1] == 1:  # single frame (image): just normalize and flatten.
            x = self.norm(x)
            x = einops.rearrange(x, "b t l c -> b (t l) c")
            return x

        shortcut = x
        x = einops.rearrange(x, "b t l c -> (b t) l c")
        x = self.norm(x)
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0])
        x = shortcut + self.scale * self.droppath(x)
        x = einops.rearrange(x, "b t l c -> b (t l) c")
        return x
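

# Usage sketch (added for illustration; not part of the original module).
# Tokens attend within each frame; the output is flattened over
# (frames, tokens) so it can feed a standard transformer block.
# >>> spatial = SpatialAttention(input_dim=768)
# >>> spatial(torch.randn(2, 8, 197, 768)).shape    # [bs, nframes, l, c] in
# torch.Size([2, 1576, 768])                        # [bs, nframes * l, c] out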


class TemporalAttention(nn.Module):
    """Perform temporal self-attention across frames, independently per spatial token."""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
            droppath_rate (float): Drop path rate on the residual branch.
        """
        super(TemporalAttention, self).__init__()

        self._input_dim = input_dim
        self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.linear = Linear(input_dim, input_dim)
        self.droppath = DropPath(droppath_rate)
        # self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.scale = 1.0

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after attention, flattened to shape [bs, nframes * l, c].
        """
        if x.shape[1] == 1:  # single frame (image): just normalize and flatten.
            x = self.norm(x)
            x = einops.rearrange(x, "b t l c -> b (t l) c")
            return x

        shortcut = x
        x = einops.rearrange(x, "b t l c -> (b l) t c")
        x = self.norm(x)
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0])
        x = shortcut + self.scale * self.droppath(x)
        x = einops.rearrange(x, "b t l c -> b (t l) c")
        return x
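

# Usage sketch (added for illustration; not part of the original module).
# Each spatial token attends over the frame axis; like SpatialAttention, the
# result is flattened over (frames, tokens).
# >>> temporal = TemporalAttention(input_dim=768)
# >>> temporal(torch.randn(2, 8, 197, 768)).shape
# torch.Size([2, 1576, 768])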


class WindowTemporalAttention(nn.Module):
    """Perform temporal self-attention within local spatial windows."""

    def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
            droppath_rate (float): Drop path rate on the residual branch.
            window_size (tuple): Spatial window size (height, width), in patches.
        """
        super().__init__()

        self._input_dim = input_dim
        self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.droppath = DropPath(droppath_rate)
        self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.wh, self.ww = window_size
        # logger.info(f"WindowTemporalAttention: window_size: {window_size}")

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x
        shortcut = x

        h = w = int(math.sqrt(x.shape[2] - 1))
        cls_token = x[:, :, :1, :]
        # group patch tokens into (wh, ww) windows; attention runs over the
        # nframes * wh * ww tokens of each window.
        x = einops.rearrange(
            x[:, :, 1:, :],
            "b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c",
            nh=h // self.wh,
            wh=self.wh,
            nw=w // self.ww,
            ww=self.ww,
        )
        x = self.norm(x)
        x = self.temporal_attn(x, x, x)[0]
        x = einops.rearrange(
            x,
            "(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c",
            wh=self.wh,
            ww=self.ww,
            nh=h // self.wh,
            nw=w // self.ww,
        )
        # add back the cls token.
        x = torch.concat([cls_token, x], dim=2)
        return shortcut + self.scale * self.droppath(x)
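

# Usage sketch (added for illustration; not part of the original module).
# With a 14 x 14 patch grid and window_size=(2, 2), each attention call covers
# nframes * 2 * 2 tokens, and the input shape is preserved.
# >>> win_attn = WindowTemporalAttention(input_dim=768, window_size=(2, 2))
# >>> win_attn(torch.randn(2, 8, 197, 768)).shape
# torch.Size([2, 8, 197, 768])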


class X_CLIP(nn.Module):
    """Cross-frame communication via per-frame message tokens, in the style of X-CLIP."""

    def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
            droppath_rate (float): Drop path rate on the message-attention branch.
            num_prompts (int): Number of prompt tokens per frame; the last one is
                overwritten by the message token.
        """
        super().__init__()

        d_model = input_dim

        self.message_fc = nn.Linear(d_model, d_model)
        self.message_ln = LayerNorm(d_model, eps=1e-12)
        self.message_attn = nn.MultiheadAttention(d_model, d_model // 64)
        self.num_prompts = num_prompts

        self.droppath = DropPath(droppath_rate)

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x
        # build one message token per frame from its cls token.
        msg_token = self.message_ln(self.message_fc(x[:, :, 0, :]))  # [b, t, c]
        msg_token = rearrange(msg_token, "b t c -> t b c")
        # message tokens exchange information across frames.
        msg_token = msg_token + self.droppath(
            self.message_attn(msg_token, msg_token, msg_token)[0]
        )
        msg_token = rearrange(msg_token, "t b c -> b t c")
        # replace the last prompt token with the message token.
        x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2)  # [b, t, l, c]
        return x
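

# Usage sketch (added for illustration; not part of the original module).
# The module assumes the last token of each frame is a prompt slot: it is
# overwritten by a message token derived from the frame's cls token after the
# message tokens attend to each other across frames.
# >>> xclip = X_CLIP(input_dim=768, num_prompts=1)
# >>> xclip(torch.randn(2, 8, 197, 768)).shape      # shape is unchanged
# torch.Size([2, 8, 197, 768])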


class TemporalS4(nn.Module):
    """Perform temporal modeling with an S4 structured state-space layer."""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
            droppath_rate (float): Drop path rate on the residual branch.
        """
        super().__init__()
        from .s4 import S4

        self._input_dim = input_dim
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.droppath = DropPath(droppath_rate)
        self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True)

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x

        shortcut = x
        x = self.norm(x)
        # flatten (frames, tokens) into one sequence; with transposed=True the
        # S4 layer consumes [b, c, seq].
        x = einops.rearrange(x, "b t l c -> b c (t l)")
        x, _ = self.s4(x)
        x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1])
        return shortcut + self.scale * self.droppath(x)
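

# Usage sketch (added for illustration; not part of the original module). It
# assumes the sibling `s4` module provides the S4 layer imported above, which
# with transposed=True maps [b, c, seq] to (output, state) of the same shape.
# >>> s4_block = TemporalS4(input_dim=768)
# >>> s4_block(torch.randn(2, 8, 197, 768)).shape
# torch.Size([2, 8, 197, 768])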