HOIGaze/model/transformer.py

import math

import torch
import torch.nn.functional as F
from torch import nn

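# Pre-norm multi-head attention blocks with residual connections:
# self-attention over the temporal (T) or spatial (S) axis, and
# cross-attention against a conditioning feature sequence xf.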
class temporal_self_attention(nn.Module):
    """Multi-head self-attention over the temporal axis, with a residual connection."""

    def __init__(self, latent_dim, num_head, dropout):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim, bias=False)
        self.key = nn.Linear(latent_dim, latent_dim, bias=False)
        self.value = nn.Linear(latent_dim, latent_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: B, T, D
        returns: B, T, D
        """
        B, T, D = x.shape
        H = self.num_head
        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, T, D
        key = self.key(self.norm(x)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, T, H, -1)
        # B, T, T, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.norm(x)).view(B, T, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
        y = x + y
        return y

class spatial_self_attention(nn.Module):
    """Multi-head self-attention over the spatial axis, with a residual connection."""

    def __init__(self, latent_dim, num_head, dropout):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim, bias=False)
        self.key = nn.Linear(latent_dim, latent_dim, bias=False)
        self.value = nn.Linear(latent_dim, latent_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: B, S, D
        returns: B, S, D
        """
        B, S, D = x.shape
        H = self.num_head
        # B, S, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, S, D
        key = self.key(self.norm(x)).unsqueeze(1)
        query = query.view(B, S, H, -1)
        key = key.view(B, S, H, -1)
        # B, S, S, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.norm(x)).view(B, S, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
        y = x + y
        return y

class temporal_cross_attention(nn.Module):
    """Multi-head cross-attention from temporal tokens x to conditioning features xf, with a residual connection."""

    def __init__(self, latent_dim, mod_dim, num_head, dropout):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.mod_norm = nn.LayerNorm(mod_dim)
        self.query = nn.Linear(latent_dim, latent_dim, bias=False)
        self.key = nn.Linear(mod_dim, latent_dim, bias=False)
        self.value = nn.Linear(mod_dim, latent_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, xf):
        """
        x: B, T, D
        xf: B, N, L
        returns: B, T, D
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, N, D
        key = self.key(self.mod_norm(xf)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, N, H, -1)
        # B, T, N, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
        y = x + y
        return y

class spatial_cross_attention(nn.Module):
    """Multi-head cross-attention from spatial tokens x to conditioning features xf, with a residual connection."""

    def __init__(self, latent_dim, mod_dim, num_head, dropout):
        super().__init__()
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.mod_norm = nn.LayerNorm(mod_dim)
        self.query = nn.Linear(latent_dim, latent_dim, bias=False)
        self.key = nn.Linear(mod_dim, latent_dim, bias=False)
        self.value = nn.Linear(mod_dim, latent_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, xf):
        """
        x: B, S, D
        xf: B, N, L
        returns: B, S, D
        """
        B, S, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # B, S, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, N, D
        key = self.key(self.mod_norm(xf)).unsqueeze(1)
        query = query.view(B, S, H, -1)
        key = key.view(B, N, H, -1)
        # B, S, N, H
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))
        value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
        y = x + y
        return y
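

# Minimal smoke test (not part of the original file): the tensor shapes and
# hyper-parameters below are illustrative assumptions, chosen only to run each
# attention block once and confirm the expected output shapes.
if __name__ == '__main__':
    B, T, S, N = 2, 10, 23, 8                  # batch, frames, spatial tokens, conditioning tokens (assumed)
    latent_dim, mod_dim, num_head = 32, 16, 4  # latent_dim must be divisible by num_head

    x_t = torch.randn(B, T, latent_dim)  # temporal token sequence
    x_s = torch.randn(B, S, latent_dim)  # spatial token sequence
    xf = torch.randn(B, N, mod_dim)      # conditioning features

    print(temporal_self_attention(latent_dim, num_head, 0.1)(x_t).shape)                # (B, T, latent_dim)
    print(spatial_self_attention(latent_dim, num_head, 0.1)(x_s).shape)                 # (B, S, latent_dim)
    print(temporal_cross_attention(latent_dim, mod_dim, num_head, 0.1)(x_t, xf).shape)  # (B, T, latent_dim)
    print(spatial_cross_attention(latent_dim, mod_dim, num_head, 0.1)(x_s, xf).shape)   # (B, S, latent_dim)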