import logging
import math

import einops
import torch
from einops import rearrange
from timm.models.layers.drop import DropPath
from torch import nn
from torch.nn import LayerNorm, Linear, MultiheadAttention

logger = logging.getLogger(__name__)


class STAdapter(nn.Module):
    """ST Adapter"""

    def __init__(
        self,
        kernel_size=(3, 3, 3),
        input_dim=768,
        hidden_dim=384,
        img_size=224,
        patch_size=16,
        drop_prob=0.1,
    ):
        super(STAdapter, self).__init__()
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.h = self.w = img_size // patch_size

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.act = nn.ReLU()
        # depthwise 3D conv mixes information across neighboring frames and patches.
        self.conv = nn.Conv3d(
            hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim
        )
        self.droppath = DropPath(drop_prob=drop_prob)

        self.scale = nn.parameter.Parameter(torch.zeros([]))

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x

        shortcut = x
        x = self.linear1(x)
        cls = x[:, :, :1, :]
        tokens = x[:, :, 1:, :]
        tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous()
        tokens = self.conv(tokens)
        tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c")
        x = torch.cat([cls, tokens], dim=2)  # [b, t, 1+h*w, c]
        x = self.act(x)
        x = self.linear2(x)

        return shortcut + self.scale * self.droppath(x)


class SpatialAttention(nn.Module):
    """Perform spatial self-attention"""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        super(SpatialAttention, self).__init__()
        self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.linear = Linear(input_dim, input_dim)
        self.droppath = DropPath(droppath_rate)
        # self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.scale = 1.0

    def forward(self, x: torch.Tensor):
        if x.shape[1] == 1:
            x = self.norm(x)
            x = einops.rearrange(x, "b t l c -> b (t l) c")
            return x  # return self if media is image

        shortcut = x
        # attend over the l tokens within each frame (spatial attention).
        x = einops.rearrange(x, "b t l c -> (b t) l c")
        x = self.norm(x)
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0])
        x = shortcut + self.scale * self.droppath(x)
        x = einops.rearrange(x, "b t l c -> b (t l) c")
        return x


class TemporalAttention(nn.Module):
    """perform temporal self-attention"""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
        """
        super(TemporalAttention, self).__init__()

        self._input_dim = input_dim
        self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.linear = Linear(input_dim, input_dim)
        self.droppath = DropPath(droppath_rate)
        # self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.scale = 1.0

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter, flattened to shape [bs, nframes * l, c].
        """
        if x.shape[1] == 1:  # for single frame, return itself.
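            # A single frame offers no temporal context to attend over, so only
            # normalize and flatten to the [b, t*l, c] layout that the
            # multi-frame path below also returns.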
            x = self.norm(x)
            x = einops.rearrange(x, "b t l c -> b (t l) c")
            return x

        shortcut = x
        # attend over the t frames at each spatial location (temporal attention).
        x = einops.rearrange(x, "b t l c -> (b l) t c")
        x = self.norm(x)
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0])
        x = shortcut + self.scale * self.droppath(x)
        x = einops.rearrange(x, "b t l c -> b (t l) c")
        return x


class WindowTemporalAttention(nn.Module):
    """perform windowed temporal self-attention"""

    def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
        """
        super().__init__()

        self._input_dim = input_dim
        self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64)
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.droppath = DropPath(droppath_rate)
        self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.wh, self.ww = window_size
        # logger.info(f"WindowTemporalAttention: window_size: {window_size}")

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x

        shortcut = x
        h = w = int(math.sqrt(x.shape[2] - 1))
        cls_token = x[:, :, :1, :]
        # group patches into (wh, ww) windows; attention runs jointly over the
        # frames and the patches inside each window.
        x = einops.rearrange(
            x[:, :, 1:, :],
            "b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c",
            nh=h // self.wh,
            wh=self.wh,
            nw=w // self.ww,
            ww=self.ww,
        )
        x = self.norm(x)
        x = self.temporal_attn(x, x, x)[0]
        x = einops.rearrange(
            x,
            "(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c",
            wh=self.wh,
            ww=self.ww,
            nh=h // self.wh,
            nw=w // self.ww,
        )
        # add back cls token.
        x = torch.concat([cls_token, x], dim=2)
        return shortcut + self.scale * self.droppath(x)


class X_CLIP(nn.Module):
    """X-CLIP-style cross-frame communication via per-frame message tokens"""

    def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
        """
        super().__init__()

        d_model = input_dim

        self.message_fc = nn.Linear(d_model, d_model)
        self.message_ln = LayerNorm(d_model, eps=1e-12)
        self.message_attn = nn.MultiheadAttention(d_model, d_model // 64)
        self.num_prompts = num_prompts

        self.droppath = DropPath(droppath_rate)

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
            return x

        msg_token = self.message_ln(self.message_fc(x[:, :, 0, :]))  # [b, t, c]
        msg_token = rearrange(msg_token, "b t c -> t b c")
        msg_token = msg_token + self.droppath(
            self.message_attn(msg_token, msg_token, msg_token)[0]
        )
        msg_token = rearrange(msg_token, "t b c -> b t c")

        # replace the last prompt token with msg_token.
        x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2)  # [b, t, l, c]
        return x


class TemporalS4(nn.Module):
    """perform temporal modeling with a bidirectional S4 layer"""

    def __init__(self, input_dim=768, droppath_rate=0.1):
        """
        Kwargs:
            input_dim (int): The input feature dimension.
        """
        super().__init__()
        from .s4 import S4

        self._input_dim = input_dim
        self.norm = LayerNorm(input_dim, eps=1e-12)
        self.droppath = DropPath(droppath_rate)
        self.scale = nn.parameter.Parameter(torch.zeros([]))
        self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True)

    def forward(self, x: torch.Tensor):
        """forward

        Args:
            x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w

        Returns: features after adapter. The same shape as input.
        """
        if x.shape[1] == 1:  # for single frame, return itself.
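            # S4 needs a temporal sequence to model; with one frame there is
            # nothing to mix across time, so return the input unchanged.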
            return x

        shortcut = x
        x = self.norm(x)
        x = einops.rearrange(x, "b t l c -> b c (t l)")
        x, _ = self.s4(x)
        x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1])
        return shortcut + self.scale * self.droppath(x)
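

# ---------------------------------------------------------------------------
# Minimal shape-check sketch. This block is not part of the original module;
# it is an illustrative assumption using the default hyperparameters above
# (img_size=224, patch_size=16, so l = 1 + 14 * 14 = 197) and random features.
# TemporalS4 is skipped because it depends on the optional local `.s4` module.
if __name__ == "__main__":
    bs, nframes, num_tokens, c = 2, 4, 1 + 14 * 14, 768
    dummy = torch.randn(bs, nframes, num_tokens, c)

    print(STAdapter()(dummy).shape)  # [2, 4, 197, 768]; shape preserved
    print(SpatialAttention()(dummy).shape)  # [2, 4 * 197, 768]; flattened output
    print(TemporalAttention()(dummy).shape)  # [2, 4 * 197, 768]; flattened output
    print(WindowTemporalAttention()(dummy).shape)  # [2, 4, 197, 768]
    print(X_CLIP()(dummy).shape)  # [2, 4, 197, 768]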