initial commit

Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

@@ -0,0 +1,286 @@
import logging
import math
import einops
import torch
from einops import rearrange
from timm.models.layers.drop import DropPath
from torch import nn
from torch.nn import LayerNorm, Linear, MultiheadAttention
logger = logging.getLogger(__name__)
class STAdapter(nn.Module):
"""ST Adapter"""
def __init__(
self,
kernel_size=(3, 3, 3),
input_dim=768,
hidden_dim=384,
img_size=224,
patch_size=16,
drop_prob=0.1,
):
super(STAdapter, self).__init__()
self.kernel_size = kernel_size
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.h = self.w = img_size // patch_size
self.linear1 = nn.Linear(input_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, input_dim)
self.act = nn.ReLU()
self.conv = nn.Conv3d(
hidden_dim, hidden_dim, kernel_size=kernel_size, padding="same", groups=hidden_dim
)
self.droppath = DropPath(drop_prob=drop_prob)
self.scale = nn.parameter.Parameter(torch.zeros([]))
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
x = self.linear1(x)
cls = x[:, :, :1, :]
tokens = x[:, :, 1:, :]
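# reshape the patch tokens into a [b, c, t, h, w] volume so the depthwise 3D conv
# can mix information across neighbouring frames and spatial positions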
tokens = einops.rearrange(tokens, "b t (h w) c -> b c t h w", h=self.h).contiguous()
tokens = self.conv(tokens)
tokens = einops.rearrange(tokens, "b c t h w -> b t (h w) c")
x = torch.cat([cls, tokens], dim=2) # [b, t, 1+h*w, c]
x = self.act(x)
x = self.linear2(x)
return shortcut + self.scale * self.droppath(x)
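
# Usage sketch (illustrative, not part of the original module): with the default
# img_size=224 and patch_size=16 the token grid is 14x14, so l = 1 + 14*14 = 197.
#   adapter = STAdapter()
#   out = adapter(torch.randn(2, 8, 197, 768))  # -> [2, 8, 197, 768], same shape as input
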
class SpatialAttention(nn.Module):
"""Perfrom spatial self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
super(SpatialAttention, self).__init__()
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.linear = Linear(input_dim, input_dim)
self.droppath = DropPath(droppath_rate)
# self.scale = nn.parameter.Parameter(torch.zeros([]))
self.scale = 1.0
def forward(self, x: torch.Tensor):
if x.shape[1] == 1:
x = self.norm(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x # return self if media is image
shortcut = x
x = einops.rearrange(x, 'b t l c -> (b t) l c')
x = self.norm(x)
x = self.attn(x, x, x)[0]
x = einops.rearrange(x, "(b t) l c -> b t l c", b=shortcut.shape[0])
x = shortcut + self.scale * self.droppath(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
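
# Usage sketch (illustrative, not part of the original module): attention runs over the
# l tokens of each frame independently; note that the output is flattened over frames.
#   spatial = SpatialAttention()
#   out = spatial(torch.randn(2, 8, 197, 768))  # -> [2, 8*197, 768]
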
class TemporalAttention(nn.Module):
"""perform temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super(TemporalAttention, self).__init__()
self._input_dim = input_dim
self.attn = MultiheadAttention(input_dim, num_heads=input_dim // 64, batch_first=True)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.linear = Linear(input_dim, input_dim)
self.droppath = DropPath(droppath_rate)
# self.scale = nn.parameter.Parameter(torch.zeros([]))
self.scale = 1.0
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after temporal attention, flattened to [bs, nframes*l, c].
"""
if x.shape[1] == 1: # for single frame, return itself.
x = self.norm(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
shortcut = x
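# treat every spatial position as an independent sequence of length t,
# so attention is computed across frames at the same token position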
x = einops.rearrange(x, "b t l c -> (b l) t c")
x = self.norm(x)
x = self.attn(x, x, x)[0]
x = einops.rearrange(x, "(b l) t c -> b t l c", b=shortcut.shape[0])
x = shortcut + self.scale * self.droppath(x)
x = einops.rearrange(x, "b t l c -> b (t l) c")
return x
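
# Usage sketch (illustrative, not part of the original module): attention runs across the
# t frames at each token position; as above, the output is flattened over frames.
#   temporal = TemporalAttention()
#   out = temporal(torch.randn(2, 8, 197, 768))  # -> [2, 8*197, 768]
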
class WindowTemporalAttention(nn.Module):
"""perform windowed temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
self._input_dim = input_dim
self.temporal_attn = MultiheadAttention(input_dim, num_heads=input_dim // 64)
self.norm = LayerNorm(input_dim, eps=1e-12)
self.droppath = DropPath(droppath_rate)
self.scale = nn.parameter.Parameter(torch.zeros([]))
self.wh, self.ww = window_size
# logger.info(f"WindowTemporalAttention: window_size: {window_size}")
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
h = w = int(math.sqrt(x.shape[2] - 1))
cls_token = x[:, :, :1, :]
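# partition the h x w patch grid into non-overlapping wh x ww windows and attend
# over the t*wh*ww tokens of each window; the batch dimension groups (b, nh, nw)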
x = einops.rearrange(
x[:, :, 1:, :],
"b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c",
nh=h // self.wh,
wh=self.wh,
nw=w // self.ww,
ww=self.ww,
)
x = self.norm(x)
x = self.temporal_attn(x, x, x)[0]
x = einops.rearrange(
x,
"(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c",
wh=self.wh,
ww=self.ww,
nh=h // self.wh,
nw=w // self.ww,
)
# add back cls token.
x = torch.concat([cls_token, x], dim=2)
return shortcut + self.scale * self.droppath(x)
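
# Usage sketch (illustrative, not part of the original module): h and w are inferred from
# the token count and must be divisible by the window size (default 2x2).
#   win_temporal = WindowTemporalAttention()
#   out = win_temporal(torch.randn(2, 8, 197, 768))  # -> [2, 8, 197, 768], same shape as input
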
class X_CLIP(nn.Module):
"""perform windowed temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
d_model = input_dim
self.message_fc = nn.Linear(d_model, d_model)
self.message_ln = LayerNorm(d_model, eps=1e-12)
self.message_attn = nn.MultiheadAttention(d_model, d_model // 64)
self.num_prompts = num_prompts
self.droppath = DropPath(droppath_rate)
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
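# build one message token per frame from that frame's CLS token, let the message
# tokens attend to each other across frames, then write them back into the sequence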
msg_token = self.message_ln(self.message_fc(x[:, :, 0, :])) # [b, t, c]
msg_token = rearrange(msg_token, "b t c -> t b c")
msg_token = msg_token + self.droppath(
self.message_attn(msg_token, msg_token, msg_token)[0]
)
msg_token = rearrange(msg_token, "t b c -> b t c")
# replace the last prompt token with msg_token.
x = torch.cat([x[:, :, :-1, :], msg_token.unsqueeze(2)], dim=2)  # [b, t, l, c]
return x
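
# Usage sketch (illustrative, not part of the original module): the last token of every
# frame is replaced by its cross-frame message token, so the shape is unchanged.
#   xclip = X_CLIP()
#   out = xclip(torch.randn(2, 8, 197, 768))  # -> [2, 8, 197, 768]
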
class TemporalS4(nn.Module):
"""perform temporal self-attention"""
def __init__(self, input_dim=768, droppath_rate=0.1):
"""
Kwargs:
input_dim (int): The input feature dimension.
"""
super().__init__()
from .s4 import S4
self._input_dim = input_dim
self.norm = LayerNorm(input_dim, eps=1e-12)
self.droppath = DropPath(droppath_rate)
self.scale = nn.parameter.Parameter(torch.zeros([]))
self.s4 = S4(d_model=input_dim, bidirectional=True, transposed=True)
def forward(self, x: torch.Tensor):
"""forward
Args:
x (torch.Tensor): input features. Shape: [bs, nframes, l, c]. l = 1 + h*w
Returns: features after adapter. The same shape as input.
"""
if x.shape[1] == 1: # for single frame, return itself.
return x
shortcut = x
x = self.norm(x)
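# flatten frames and tokens into a single sequence of length t*l so the S4 layer
# (which takes [batch, channels, length] input when transposed=True) can model it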
x = einops.rearrange(x, "b t l c -> b c (t l)")
x, _ = self.s4(x)
x = einops.rearrange(x, "b c (t l) -> b t l c", t=shortcut.shape[1])
return shortcut + self.scale * self.droppath(x)
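
# Usage sketch (illustrative, not part of the original module; requires the companion s4 module):
#   temporal_s4 = TemporalS4()
#   out = temporal_s4(torch.randn(2, 8, 197, 768))  # -> [2, 8, 197, 768], same shape as input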