mtomnet/boss/models/resnet.py

250 lines
9.7 KiB
Python
Raw Permalink Normal View History

2025-01-10 15:39:20 +01:00
import torch
import torch.nn as nn
import torchvision.models as models
from .utils import left_bias, right_bias
class ResNet(nn.Module):
    """Frame-wise predictor: ResNet-34 image features fused with optional modalities.

    Every frame is encoded independently (there is no temporal model). The
    512-d image feature of a frame is concatenated with the embeddings of
    whichever of pose / gaze / bounding-box / OCR inputs are supplied, and the
    fused vector is fed to two linear heads (``left`` and ``right``) that each
    emit 27 raw scores.
    """

    def __init__(self, input_dim, device):
        """Create the backbone, the modality embeddings and the output heads.

        Args:
            input_dim: Width of the fused per-frame feature consumed by the
                heads: 512 for the image plus 150 (pose), 6 (gaze),
                108 (bbox) and 64 (OCR) for every modality that will be
                passed to ``forward``.
            device: Device the image frames are moved to in ``forward``.
        """
        super().__init__()
        # Conv: pretrained ResNet-34 with its last child module (the
        # classification layer in torchvision's ResNet) stripped off.
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=``; kept for compatibility with older installs.
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        # FFs: one 27-way head per side
        self.left = nn.Linear(input_dim, 27)
        self.right = nn.Linear(input_dim, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Score every frame of a batch of sequences.

        Args:
            images: Tensor of shape (batch, seq, channels, height, width).
            poses: Optional tensor of shape (batch, seq, 150), or ``None``.
            gazes: Optional tensor of shape (batch, seq, 6), or ``None``.
            bboxes: Optional tensor of shape (batch, seq, 108), or ``None``.
            ocr_tensor: Optional tensor of shape (batch, 729) shared by all
                frames of a sequence, or ``None``.

        Returns:
            Tuple ``(left, right, None)`` where ``left`` and ``right`` have
            shape (batch, seq, 27). The third slot is always ``None``.
        """
        # Unpacking all five dims keeps the implicit "images must be 5-D" check.
        batch_size, sequence_len, _, _, _ = images.shape

        # The OCR embedding does not depend on the frame index, so compute it
        # once instead of once per frame.
        ocr_feat = None
        if ocr_tensor is not None:
            ocr_feat = self.relu(self.ocr_ff(ocr_tensor))

        frame_feats = []
        for i in range(sequence_len):
            # Only the image frames are moved to ``self.device``; the other
            # inputs are expected to already live on the right device.
            frame = images[:, i].to(self.device)
            parts = [self.resnet(frame).view(batch_size, 512)]
            if poses is not None:
                parts.append(self.relu(self.pose_ff(poses[:, i].float())))
            if gazes is not None:
                parts.append(self.relu(self.gaze_ff(gazes[:, i].float())))
            if bboxes is not None:
                parts.append(self.relu(self.bbox_ff(bboxes[:, i].float())))
            if ocr_feat is not None:
                parts.append(ocr_feat)
            frame_feats.append(torch.cat(parts, 1))

        # (batch, seq, feature)
        fused = torch.stack(frame_feats, dim=1)
        # Dropout is applied separately per head, so in training mode each
        # head sees its own dropout mask.
        left_beliefs = self.left(self.dropout(fused))
        right_beliefs = self.right(self.dropout(fused))
        return left_beliefs, right_beliefs, None
class ResNetGRU(nn.Module):
    """ResNet-34 frame encoder followed by a GRU over the frame sequence.

    Each frame's 512-d image feature is concatenated with the embeddings of
    whichever of pose / gaze / bounding-box / OCR inputs are supplied. The
    fused sequence is run through a single-layer GRU (hidden size 512) and
    two linear heads (``left`` and ``right``) emit 27 raw scores per frame.
    """

    def __init__(self, input_dim, device):
        """Create the backbone, the GRU, the modality embeddings and the heads.

        Args:
            input_dim: Width of the fused per-frame feature fed to the GRU:
                512 for the image plus 150 (pose), 6 (gaze), 108 (bbox) and
                64 (OCR) for every modality that will be passed to
                ``forward``.
            device: Stored on the module; ``forward`` does not move any
                tensors, so inputs must already be on the right device.
        """
        super().__init__()
        # Pretrained ResNet-34 with its last child module (the classification
        # layer in torchvision's ResNet) stripped off.
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=``; kept for compatibility with older installs.
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        self.gru = nn.GRU(input_dim, 512, batch_first=True)
        # Orthogonal weight matrices and zero biases for the recurrent layer.
        for name, param in self.gru.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)
        # FFs: one 27-way head per side
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Score every frame of a batch of sequences.

        Args:
            images: Tensor of shape (batch, seq, channels, height, width).
            poses: Optional tensor of shape (batch, seq, 150), or ``None``.
            gazes: Optional tensor of shape (batch, seq, 6), or ``None``.
            bboxes: Optional tensor of shape (batch, seq, 108), or ``None``.
            ocr_tensor: Optional tensor of shape (batch, 729) shared by all
                frames of a sequence, or ``None``.

        Returns:
            Tuple ``(left, right, None)`` where ``left`` and ``right`` have
            shape (batch, seq, 27). The third slot is always ``None``.
        """
        # Unpacking all five dims keeps the implicit "images must be 5-D" check.
        batch_size, sequence_len, _, _, _ = images.shape

        # The OCR embedding does not depend on the frame index, so compute it
        # once instead of once per frame.
        ocr_feat = None
        if ocr_tensor is not None:
            ocr_feat = self.relu(self.ocr_ff(ocr_tensor))

        frame_feats = []
        for i in range(sequence_len):
            parts = [self.resnet(images[:, i]).view(batch_size, 512)]
            if poses is not None:
                parts.append(self.relu(self.pose_ff(poses[:, i].float())))
            if gazes is not None:
                parts.append(self.relu(self.gaze_ff(gazes[:, i].float())))
            if bboxes is not None:
                parts.append(self.relu(self.bbox_ff(bboxes[:, i].float())))
            if ocr_feat is not None:
                parts.append(ocr_feat)
            frame_feats.append(torch.cat(parts, 1))

        # (batch, seq, feature), matching the GRU's ``batch_first=True``.
        rnn_inp = torch.stack(frame_feats, dim=1)
        rnn_out, _ = self.gru(rnn_inp)
        # Dropout is applied separately per head, so in training mode each
        # head sees its own dropout mask.
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))
        return left_beliefs, right_beliefs, None
class ResNetConv1D(nn.Module):
    """ResNet-34 frame encoder followed by a 1-D convolution over time.

    Each frame's 512-d image feature is concatenated with the embeddings of
    whichever of pose / gaze / bounding-box / OCR inputs are supplied. The
    fused sequence is convolved along the time axis (kernel size 5) and two
    linear heads (``left`` and ``right``) emit 27 raw scores per frame.
    """

    def __init__(self, input_dim, device):
        """Create the backbone, the temporal conv, the embeddings and the heads.

        Args:
            input_dim: Number of input channels of the temporal convolution,
                i.e. the width of the fused per-frame feature: 512 for the
                image plus 150 (pose), 6 (gaze), 108 (bbox) and 64 (OCR) for
                every modality that will be passed to ``forward``.
            device: Stored on the module; ``forward`` does not move any
                tensors, so inputs must already be on the right device.
        """
        super().__init__()
        # Pretrained ResNet-34 with its last child module (the classification
        # layer in torchvision's ResNet) stripped off.
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=``; kept for compatibility with older installs.
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        # padding=4 (= kernel_size - 1) on both ends; ``forward`` drops the
        # last 4 outputs, so the output keeps the input sequence length.
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=5, padding=4)
        # FFs: one 27-way head per side
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Score every frame of a batch of sequences.

        Args:
            images: Tensor of shape (batch, seq, channels, height, width).
            poses: Optional tensor of shape (batch, seq, 150), or ``None``.
            gazes: Optional tensor of shape (batch, seq, 6), or ``None``.
            bboxes: Optional tensor of shape (batch, seq, 108), or ``None``.
            ocr_tensor: Optional tensor of shape (batch, 729) shared by all
                frames of a sequence, or ``None``.

        Returns:
            Tuple ``(left, right, None)`` where ``left`` and ``right`` have
            shape (batch, seq, 27). The third slot is always ``None``.
        """
        # Unpacking all five dims keeps the implicit "images must be 5-D" check.
        batch_size, sequence_len, _, _, _ = images.shape

        # The OCR embedding does not depend on the frame index, so compute it
        # once instead of once per frame.
        ocr_feat = None
        if ocr_tensor is not None:
            ocr_feat = self.relu(self.ocr_ff(ocr_tensor))

        frame_feats = []
        for i in range(sequence_len):
            parts = [self.resnet(images[:, i]).view(batch_size, 512)]
            if poses is not None:
                parts.append(self.relu(self.pose_ff(poses[:, i].float())))
            if gazes is not None:
                parts.append(self.relu(self.gaze_ff(gazes[:, i].float())))
            if bboxes is not None:
                parts.append(self.relu(self.bbox_ff(bboxes[:, i].float())))
            if ocr_feat is not None:
                parts.append(ocr_feat)
            frame_feats.append(torch.cat(parts, 1))

        # Conv1d expects (batch, channels, seq), so stack frames on the last axis.
        conv1d_inp = torch.stack(frame_feats, dim=2)
        conv1d_out = self.conv1d(conv1d_inp)
        # Drop the 4 trailing positions created by the padding: output t then
        # only depends on frames t-4..t.
        conv1d_out = conv1d_out[:, :, :-4]
        # Back to (batch, seq, 512) for the linear heads.
        conv1d_out = self.relu(torch.permute(conv1d_out, (0, 2, 1)))
        # Dropout is applied separately per head, so in training mode each
        # head sees its own dropout mask.
        left_beliefs = self.left(self.dropout(conv1d_out))
        right_beliefs = self.right(self.dropout(conv1d_out))
        return left_beliefs, right_beliefs, None
class ResNetLSTM(nn.Module):
    """ResNet-34 frame encoder followed by an LSTM over the frame sequence.

    Each frame's 512-d image feature is concatenated with the embeddings of
    whichever of pose / gaze / bounding-box / OCR inputs are supplied. The
    fused sequence is run through a single-layer LSTM (hidden size 512) and
    two linear heads (``left`` and ``right``) emit 27 raw scores per frame.
    """

    def __init__(self, input_dim, device):
        """Create the backbone, the LSTM, the modality embeddings and the heads.

        Args:
            input_dim: Width of the fused per-frame feature fed to the LSTM:
                512 for the image plus 150 (pose), 6 (gaze), 108 (bbox) and
                64 (OCR) for every modality that will be passed to
                ``forward``.
            device: Stored on the module; ``forward`` does not move any
                tensors, so inputs must already be on the right device.
        """
        super().__init__()
        # Pretrained ResNet-34 with its last child module (the classification
        # layer in torchvision's ResNet) stripped off.
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=``; kept for compatibility with older installs.
        resnet = models.resnet34(pretrained=True)
        self.resnet = nn.Sequential(*(list(resnet.children())[:-1]))
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512, batch_first=True)
        # FFs: one 27-way head per side
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # Head biases start at the element-wise log of ``left_bias`` /
        # ``right_bias`` from ``.utils``.
        # NOTE(review): presumably these are 27 strictly positive per-class
        # values (e.g. class frequencies) -- confirm against ``utils``.
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()
        # modality FFs
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)
        # others
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Score every frame of a batch of sequences.

        Args:
            images: Tensor of shape (batch, seq, channels, height, width).
            poses: Optional tensor of shape (batch, seq, 150), or ``None``.
            gazes: Optional tensor of shape (batch, seq, 6), or ``None``.
            bboxes: Optional tensor of shape (batch, seq, 108), or ``None``.
            ocr_tensor: Optional tensor of shape (batch, 729) shared by all
                frames of a sequence, or ``None``.

        Returns:
            Tuple ``(left, right, None)`` where ``left`` and ``right`` have
            shape (batch, seq, 27). The third slot is always ``None``.
        """
        # Unpacking all five dims keeps the implicit "images must be 5-D" check.
        batch_size, sequence_len, _, _, _ = images.shape

        # The OCR embedding does not depend on the frame index, so compute it
        # once instead of once per frame.
        ocr_feat = None
        if ocr_tensor is not None:
            ocr_feat = self.relu(self.ocr_ff(ocr_tensor))

        frame_feats = []
        for i in range(sequence_len):
            parts = [self.resnet(images[:, i]).view(batch_size, 512)]
            if poses is not None:
                parts.append(self.relu(self.pose_ff(poses[:, i].float())))
            if gazes is not None:
                parts.append(self.relu(self.gaze_ff(gazes[:, i].float())))
            if bboxes is not None:
                parts.append(self.relu(self.bbox_ff(bboxes[:, i].float())))
            if ocr_feat is not None:
                parts.append(ocr_feat)
            frame_feats.append(torch.cat(parts, 1))

        # (batch, seq, feature), matching the LSTM's ``batch_first=True``.
        rnn_inp = torch.stack(frame_feats, dim=1)
        rnn_out, _ = self.lstm(rnn_inp)
        # Dropout is applied separately per head, so in training mode each
        # head sees its own dropout mask.
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))
        return left_beliefs, right_beliefs, None
if __name__ == '__main__':
    # Smoke test: push a dummy batch (3 sequences of 22 frames) through the
    # frame-wise ResNet model and print the shape of the "left" output.
    # Constructing the model loads pretrained ResNet-34 weights.
    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 150)
    gazes = torch.ones(3, 22, 6)
    bboxes = torch.ones(3, 22, 108)
    # The heads consume the fused per-frame feature, so input_dim must be
    # 512 (image) + 150 (pose) + 6 (gaze) + 108 (bbox) = 776; no OCR is
    # passed here. Any other value makes the forward pass fail with a
    # shape-mismatch error in the linear heads.
    input_dim = 512 + poses.shape[-1] + gazes.shape[-1] + bboxes.shape[-1]
    model = ResNet(input_dim, 'cpu')
    print(model(images, poses, gazes, bboxes, None)[0].shape)