first commit
This commit is contained in:
parent
99ce0acafb
commit
8f6b6a34e7
73 changed files with 11656 additions and 0 deletions
98
model/attended_hand_recognition.py
Normal file
98
model/attended_hand_recognition.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
from torch import nn
|
||||
import torch
|
||||
from model import graph_convolution_network
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class attended_hand_recognition(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.body_joint_number = opt.body_joint_number
|
||||
self.hand_joint_number = opt.hand_joint_number
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number
|
||||
self.input_n = opt.seq_len
|
||||
gcn_latent_features = opt.gcn_latent_features
|
||||
residual_gcns_num = opt.residual_gcns_num
|
||||
gcn_dropout = opt.gcn_dropout
|
||||
head_cnn_channels = opt.head_cnn_channels
|
||||
recognition_cnn_channels = opt.recognition_cnn_channels
|
||||
|
||||
# 1D CNN for extracting features from head directions
|
||||
in_channels_head = 3
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_head = head_cnn_channels
|
||||
out_channels_2_head = head_cnn_channels
|
||||
out_channels_head = head_cnn_channels
|
||||
|
||||
self.head_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_2_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# GCN for extracting features from body and left hand joints
|
||||
self.left_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
# GCN for extracting features from body and right hand joints
|
||||
self.right_hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
# 1D CNN for recognising attended hand (left or right)
|
||||
in_channels_recognition = self.joint_number*gcn_latent_features*2 + out_channels_head
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_recognition = recognition_cnn_channels
|
||||
out_channels_recognition = 2
|
||||
|
||||
self.recognition_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_recognition, out_channels=out_channels_1_recognition, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_recognition, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_recognition, out_channels=out_channels_recognition, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, src, input_n=15):
|
||||
|
||||
bs, seq_len, features = src.shape
|
||||
body_joints = src.clone()[:, :, :self.body_joint_number*3]
|
||||
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
|
||||
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
|
||||
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
|
||||
|
||||
left_hand_joints = torch.cat((left_hand_joints, body_joints), dim=2)
|
||||
left_hand_joints = left_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
left_hand_features = self.left_hand_gcn(left_hand_joints)
|
||||
left_hand_features = left_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
right_hand_joints = torch.cat((right_hand_joints, body_joints), dim=2)
|
||||
right_hand_joints = right_hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
right_hand_features = self.right_hand_gcn(right_hand_joints)
|
||||
right_hand_features = right_hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
head_direction = head_direction.permute(0,2,1)
|
||||
head_features = self.head_cnn(head_direction)
|
||||
|
||||
# fuse head and hand features
|
||||
features = torch.cat((left_hand_features, right_hand_features), dim=1)
|
||||
features = torch.cat((features, head_features), dim=1)
|
||||
# recognise attended hand from fused features
|
||||
prediction = self.recognition_cnn(features).permute(0, 2, 1)
|
||||
|
||||
return prediction
|
140
model/gaze_estimation.py
Normal file
140
model/gaze_estimation.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
from torch import nn
|
||||
import torch
|
||||
from model import graph_convolution_network, transformer
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class gaze_estimation(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.body_joint_number = opt.body_joint_number
|
||||
self.hand_joint_number = opt.hand_joint_number
|
||||
self.input_n = opt.seq_len
|
||||
self.object_num = opt.object_num
|
||||
gcn_latent_features = opt.gcn_latent_features
|
||||
residual_gcns_num = opt.residual_gcns_num
|
||||
gcn_dropout = opt.gcn_dropout
|
||||
head_cnn_channels = opt.head_cnn_channels
|
||||
gaze_cnn_channels = opt.gaze_cnn_channels
|
||||
self.use_self_att = opt.use_self_att
|
||||
self_att_head_num = opt.self_att_head_num
|
||||
self_att_dropout = opt.self_att_dropout
|
||||
self.use_cross_att = opt.use_cross_att
|
||||
cross_att_head_num = opt.cross_att_head_num
|
||||
cross_att_dropout = opt.cross_att_dropout
|
||||
self.use_attended_hand = opt.use_attended_hand
|
||||
self.use_attended_hand_gt = opt.use_attended_hand_gt
|
||||
if self.use_attended_hand:
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number + self.object_num
|
||||
else:
|
||||
self.joint_number = self.body_joint_number + self.hand_joint_number*2 + self.object_num*2
|
||||
|
||||
# 1D CNN for extracting features from head directions
|
||||
in_channels_head = 3
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_head = head_cnn_channels
|
||||
out_channels_2_head = head_cnn_channels
|
||||
out_channels_head = head_cnn_channels
|
||||
|
||||
self.head_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_head, out_channels=out_channels_1_head, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_head, out_channels=out_channels_2_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_2_head, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_2_head, out_channels=out_channels_head, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# GCN for extracting features from hand joints, body joints, and scene objects
|
||||
self.hand_gcn = graph_convolution_network.graph_convolution_network(in_features=3,
|
||||
latent_features=gcn_latent_features,
|
||||
node_n=self.joint_number,
|
||||
seq_len=self.input_n,
|
||||
p_dropout=gcn_dropout,
|
||||
residual_gcns_num=residual_gcns_num)
|
||||
|
||||
if self.use_self_att:
|
||||
self.head_self_att = transformer.temporal_self_attention(out_channels_head, self_att_head_num, self_att_dropout)
|
||||
self.hand_self_att = transformer.temporal_self_attention(self.joint_number*gcn_latent_features, self_att_head_num, self_att_dropout)
|
||||
|
||||
if self.use_cross_att:
|
||||
self.head_hand_cross_att = transformer.temporal_cross_attention(out_channels_head, self.joint_number*gcn_latent_features, cross_att_head_num, cross_att_dropout)
|
||||
self.hand_head_cross_att = transformer.temporal_cross_attention(self.joint_number*gcn_latent_features, out_channels_head, cross_att_head_num, cross_att_dropout)
|
||||
|
||||
# 1D CNN for estimating eye gaze
|
||||
in_channels_gaze = self.joint_number*gcn_latent_features + out_channels_head
|
||||
cnn_kernel_size = 3
|
||||
cnn_padding = (cnn_kernel_size -1)//2
|
||||
out_channels_1_gaze = gaze_cnn_channels
|
||||
out_channels_gaze = 3
|
||||
|
||||
self.gaze_cnn = nn.Sequential(
|
||||
nn.Conv1d(in_channels = in_channels_gaze, out_channels=out_channels_1_gaze, kernel_size=cnn_kernel_size, padding=cnn_padding, padding_mode='replicate'),
|
||||
nn.LayerNorm([out_channels_1_gaze, self.input_n]),
|
||||
nn.Tanh(),
|
||||
nn.Conv1d(in_channels=out_channels_1_gaze, out_channels=out_channels_gaze, kernel_size=cnn_kernel_size, padding = cnn_padding, padding_mode='replicate'),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
|
||||
def forward(self, src, input_n=15):
|
||||
|
||||
bs, seq_len, features = src.shape
|
||||
body_joints = src.clone()[:, :, :self.body_joint_number*3]
|
||||
left_hand_joints = src.clone()[:, :, self.body_joint_number*3:(self.body_joint_number+self.hand_joint_number)*3]
|
||||
right_hand_joints = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number)*3:(self.body_joint_number+self.hand_joint_number*2)*3]
|
||||
head_direction = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2)*3:(self.body_joint_number+self.hand_joint_number*2+1)*3]
|
||||
if self.object_num > 0:
|
||||
left_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3]
|
||||
left_object_position = torch.mean(left_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
|
||||
right_object_position = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3]
|
||||
right_object_position = torch.mean(right_object_position.reshape(bs, seq_len, self.object_num, 8, 3), dim=3).reshape(bs, seq_len, self.object_num*3)
|
||||
|
||||
attended_hand_prd = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2]
|
||||
left_hand_weights = torch.round(attended_hand_prd[:, :, 0:1])
|
||||
right_hand_weights = torch.round(attended_hand_prd[:, :, 1:2])
|
||||
if self.use_attended_hand_gt:
|
||||
attended_hand_gt = src.clone()[:, :, (self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+2:(self.body_joint_number+self.hand_joint_number*2+1+8*self.object_num*2)*3+3]
|
||||
left_hand_weights = 1-attended_hand_gt
|
||||
right_hand_weights = attended_hand_gt
|
||||
|
||||
if self.use_attended_hand:
|
||||
hand_joints = left_hand_joints*left_hand_weights + right_hand_joints*right_hand_weights
|
||||
else:
|
||||
hand_joints = torch.cat((left_hand_joints, right_hand_joints), dim=2)
|
||||
hand_joints = torch.cat((hand_joints, body_joints), dim=2)
|
||||
if self.object_num > 0:
|
||||
if self.use_attended_hand:
|
||||
object_position = left_object_position*left_hand_weights + right_object_position*right_hand_weights
|
||||
else:
|
||||
object_position = torch.cat((left_object_position, right_object_position), dim=2)
|
||||
hand_joints = torch.cat((hand_joints, object_position), dim=2)
|
||||
|
||||
hand_joints = hand_joints.permute(0, 2, 1).reshape(bs, -1, 3, input_n).permute(0, 2, 1, 3)
|
||||
hand_features = self.hand_gcn(hand_joints)
|
||||
hand_features = hand_features.permute(0, 2, 1, 3).reshape(bs, -1, input_n)
|
||||
|
||||
head_direction = head_direction.permute(0,2,1)
|
||||
head_features = self.head_cnn(head_direction)
|
||||
|
||||
if self.use_self_att:
|
||||
head_features = self.head_self_att(head_features.permute(0,2,1)).permute(0,2,1)
|
||||
hand_features = self.hand_self_att(hand_features.permute(0,2,1)).permute(0,2,1)
|
||||
|
||||
if self.use_cross_att:
|
||||
head_features_copy = head_features.clone()
|
||||
head_features = self.head_hand_cross_att(head_features.permute(0,2,1), hand_features.permute(0,2,1)).permute(0,2,1)
|
||||
hand_features = self.hand_head_cross_att(hand_features.permute(0,2,1), head_features_copy.permute(0,2,1)).permute(0,2,1)
|
||||
|
||||
# fuse head and hand features
|
||||
features = torch.cat((hand_features, head_features), dim=1)
|
||||
# estimate eye gaze
|
||||
prediction = self.gaze_cnn(features).permute(0, 2, 1)
|
||||
# normalize to unit vectors
|
||||
prediction = F.normalize(prediction, dim=2)
|
||||
|
||||
return prediction
|
79
model/graph_convolution_network.py
Normal file
79
model/graph_convolution_network.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
import math
|
||||
|
||||
|
||||
class graph_convolution(nn.Module):
|
||||
def __init__(self, in_features, out_features, node_n = 21, seq_len = 40, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.FloatTensor(seq_len))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||
if self.bias is not None:
|
||||
self.bias.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, input):
|
||||
y = torch.matmul(input, self.temporal_graph_weights)
|
||||
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||
|
||||
if self.bias is not None:
|
||||
return (y + self.bias)
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
class residual_graph_convolution(nn.Module):
|
||||
def __init__(self, features, node_n=21, seq_len = 40, bias=True, p_dropout=0.3):
|
||||
super().__init__()
|
||||
|
||||
self.gcn = graph_convolution(features, features, node_n=node_n, seq_len=seq_len, bias=bias)
|
||||
self.ln = nn.LayerNorm([features, node_n, seq_len], elementwise_affine=True)
|
||||
self.act_f = nn.Tanh()
|
||||
self.dropout = nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
y = self.gcn(x)
|
||||
y = self.ln(y)
|
||||
y = self.act_f(y)
|
||||
y = self.dropout(y)
|
||||
|
||||
return y + x
|
||||
|
||||
|
||||
class graph_convolution_network(nn.Module):
|
||||
def __init__(self, in_features, latent_features, node_n=21, seq_len=40, p_dropout=0.3, residual_gcns_num=1):
|
||||
super().__init__()
|
||||
self.residual_gcns_num = residual_gcns_num
|
||||
self.seq_len = seq_len
|
||||
|
||||
self.start_gcn = graph_convolution(in_features=in_features, out_features=latent_features, node_n=node_n, seq_len=seq_len)
|
||||
|
||||
self.residual_gcns = []
|
||||
for i in range(residual_gcns_num):
|
||||
self.residual_gcns.append(residual_graph_convolution(features=latent_features, node_n=node_n, seq_len=seq_len*2, p_dropout=p_dropout))
|
||||
self.residual_gcns = nn.ModuleList(self.residual_gcns)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.start_gcn(x)
|
||||
|
||||
y = torch.cat((y, y), dim=3)
|
||||
for i in range(self.residual_gcns_num):
|
||||
y = self.residual_gcns[i](y)
|
||||
y = y[:, :, :, :self.seq_len]
|
||||
|
||||
return y
|
138
model/transformer.py
Normal file
138
model/transformer.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import layer_norm, nn
|
||||
import math
|
||||
|
||||
|
||||
class temporal_self_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, T, D
|
||||
"""
|
||||
B, T, D = x.shape
|
||||
H = self.num_head
|
||||
# B, T, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, T, D
|
||||
key = self.key(self.norm(x)).unsqueeze(1)
|
||||
query = query.view(B, T, H, -1)
|
||||
key = key.view(B, T, H, -1)
|
||||
# B, T, T, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.norm(x)).view(B, T, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class spatial_self_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, S, D
|
||||
"""
|
||||
B, S, D = x.shape
|
||||
H = self.num_head
|
||||
# B, S, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, S, D
|
||||
key = self.key(self.norm(x)).unsqueeze(1)
|
||||
query = query.view(B, S, H, -1)
|
||||
key = key.view(B, S, H, -1)
|
||||
# B, S, S, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.norm(x)).view(B, S, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class temporal_cross_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, mod_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.mod_norm = nn.LayerNorm(mod_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, xf):
|
||||
"""
|
||||
x: B, T, D
|
||||
xf: B, N, L
|
||||
"""
|
||||
B, T, D = x.shape
|
||||
N = xf.shape[1]
|
||||
H = self.num_head
|
||||
# B, T, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, N, D
|
||||
key = self.key(self.mod_norm(xf)).unsqueeze(1)
|
||||
query = query.view(B, T, H, -1)
|
||||
key = key.view(B, N, H, -1)
|
||||
# B, T, N, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
|
||||
y = x + y
|
||||
return y
|
||||
|
||||
|
||||
class spatial_cross_attention(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, mod_dim, num_head, dropout):
|
||||
super().__init__()
|
||||
self.num_head = num_head
|
||||
self.norm = nn.LayerNorm(latent_dim)
|
||||
self.mod_norm = nn.LayerNorm(mod_dim)
|
||||
self.query = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.key = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.value = nn.Linear(mod_dim, latent_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, xf):
|
||||
"""
|
||||
x: B, S, D
|
||||
xf: B, N, L
|
||||
"""
|
||||
B, S, D = x.shape
|
||||
N = xf.shape[1]
|
||||
H = self.num_head
|
||||
# B, S, 1, D
|
||||
query = self.query(self.norm(x)).unsqueeze(2)
|
||||
# B, 1, N, D
|
||||
key = self.key(self.mod_norm(xf)).unsqueeze(1)
|
||||
query = query.view(B, S, H, -1)
|
||||
key = key.view(B, N, H, -1)
|
||||
# B, S, N, H
|
||||
attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
|
||||
weight = self.dropout(F.softmax(attention, dim=2))
|
||||
value = self.value(self.mod_norm(xf)).view(B, N, H, -1)
|
||||
y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, S, D)
|
||||
y = x + y
|
||||
return y
|
Loading…
Add table
Add a link
Reference in a new issue