import torch
import torch.nn as nn
import torchvision.models as models

from .utils import left_bias, right_bias


class _FusionBackbone(nn.Module):
    """Shared ResNet-34 frame encoder with optional per-modality fusion.

    Each frame is embedded to a 512-d vector by a pretrained ResNet-34
    (final FC layer removed); pose, gaze, bbox and OCR features, when
    provided, are projected by small feed-forward layers and concatenated.
    Subclasses add a temporal head (none / GRU / causal Conv1d / LSTM) and
    the two 27-way belief classifiers `self.left` / `self.right`.
    """

    def __init__(self, device):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # (use `weights=ResNet34_Weights.DEFAULT`); kept for compatibility
        # with the torchvision version this project pins.
        resnet = models.resnet34(pretrained=True)
        # Drop the classification head; per-frame output is a 512-d feature.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # Modality projections. Input sizes are fixed by the upstream
        # extractors: 150-d pose, 6-d gaze, 108-d bounding boxes, and a
        # 729-d OCR embedding compressed to 64.
        self.pose_ff = nn.Linear(150, 150)
        self.gaze_ff = nn.Linear(6, 6)
        self.bbox_ff = nn.Linear(108, 108)
        self.ocr_ff = nn.Linear(729, 64)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.device = device

    def _encode_sequence(self, images, poses, gazes, bboxes, ocr_tensor):
        """Encode every frame and fuse whichever modalities are present.

        Args:
            images: (batch, seq_len, 3, H, W) frame tensor.
            poses: (batch, seq_len, 150) or None.
            gazes: (batch, seq_len, 6) or None.
            bboxes: (batch, seq_len, 108) or None.
            ocr_tensor: (batch, 729) or None; constant across the sequence.

        Returns:
            (batch, seq_len, feat_dim) tensor, feat_dim = 512 plus the sizes
            of the supplied modality projections.
        """
        batch_size, sequence_len = images.shape[:2]

        # The OCR projection does not depend on the timestep; compute once
        # instead of once per frame (the original recomputed it in the loop).
        ocr_feat = None
        if ocr_tensor is not None:
            ocr_feat = self.relu(self.ocr_ff(ocr_tensor.to(self.device)))

        frame_feats = []
        for t in range(sequence_len):
            # Every input is moved to self.device; the original did this
            # only for images (and only in one of the four models), which
            # breaks torch.cat on a CUDA device.
            feat = self.resnet(images[:, t].to(self.device))
            feat = feat.view(batch_size, 512)
            if poses is not None:
                pose_feat = self.relu(
                    self.pose_ff(poses[:, t].float().to(self.device)))
                feat = torch.cat([feat, pose_feat], 1)
            if gazes is not None:
                gaze_feat = self.relu(
                    self.gaze_ff(gazes[:, t].float().to(self.device)))
                feat = torch.cat([feat, gaze_feat], 1)
            if bboxes is not None:
                bbox_feat = self.relu(
                    self.bbox_ff(bboxes[:, t].float().to(self.device)))
                feat = torch.cat([feat, bbox_feat], 1)
            if ocr_feat is not None:
                feat = torch.cat([feat, ocr_feat], 1)
            frame_feats.append(feat)

        # Stack along a new time axis -> (batch, seq_len, feat_dim).
        return torch.stack(frame_feats, dim=1)


class ResNet(_FusionBackbone):
    """Frame-wise baseline: no temporal model, per-frame linear belief heads.

    `input_dim` must equal the fused feature size, i.e. 512 plus the sizes
    of the modality projections that will be passed to `forward`
    (e.g. 512 + 150 + 6 + 108 = 776 with pose, gaze and bbox inputs).
    """

    def __init__(self, input_dim, device):
        super().__init__(device)
        # 27-way belief distributions for the left and right participant.
        self.left = nn.Linear(input_dim, 27)
        self.right = nn.Linear(input_dim, 27)

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Return (left_beliefs, right_beliefs, None), each (batch, seq, 27)."""
        feats = self._encode_sequence(images, poses, gazes, bboxes, ocr_tensor)
        left_beliefs = self.left(self.dropout(feats))
        right_beliefs = self.right(self.dropout(feats))
        return left_beliefs, right_beliefs, None


class ResNetGRU(_FusionBackbone):
    """Fused frame features fed through a single-layer GRU (hidden size 512)."""

    def __init__(self, input_dim, device):
        super().__init__(device)
        self.gru = nn.GRU(input_dim, 512, batch_first=True)
        # Orthogonal recurrent/input weights and zero biases — a standard
        # stable initialization for RNNs.
        for name, param in self.gru.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Return (left_beliefs, right_beliefs, None), each (batch, seq, 27)."""
        feats = self._encode_sequence(images, poses, gazes, bboxes, ocr_tensor)
        rnn_out, _ = self.gru(feats)
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))
        return left_beliefs, right_beliefs, None


class ResNetConv1D(_FusionBackbone):
    """Fused frame features fed through a causal temporal convolution."""

    def __init__(self, input_dim, device):
        super().__init__(device)
        # kernel_size=5 with padding=4 and trimming the trailing 4 outputs
        # makes the convolution causal: output at step t only sees inputs <= t.
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=512,
                                kernel_size=5, padding=4)
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Return (left_beliefs, right_beliefs, None), each (batch, seq, 27)."""
        feats = self._encode_sequence(images, poses, gazes, bboxes, ocr_tensor)
        conv_in = feats.permute(0, 2, 1)          # (batch, feat, seq)
        conv_out = self.conv1d(conv_in)[:, :, :-4]  # drop the acausal tail
        conv_out = self.relu(conv_out.permute(0, 2, 1))
        left_beliefs = self.left(self.dropout(conv_out))
        right_beliefs = self.right(self.dropout(conv_out))
        return left_beliefs, right_beliefs, None


class ResNetLSTM(_FusionBackbone):
    """Fused frame features fed through a single-layer LSTM (hidden size 512)."""

    def __init__(self, input_dim, device):
        super().__init__(device)
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=512,
                            batch_first=True)
        self.left = nn.Linear(512, 27)
        self.right = nn.Linear(512, 27)
        # Start the heads at the log class priors so the initial predicted
        # distribution matches the empirical label frequencies.
        self.left.bias.data = torch.tensor(left_bias).log()
        self.right.bias.data = torch.tensor(right_bias).log()

    def forward(self, images, poses, gazes, bboxes, ocr_tensor):
        """Return (left_beliefs, right_beliefs, None), each (batch, seq, 27)."""
        feats = self._encode_sequence(images, poses, gazes, bboxes, ocr_tensor)
        rnn_out, _ = self.lstm(feats)
        left_beliefs = self.left(self.dropout(rnn_out))
        right_beliefs = self.right(self.dropout(rnn_out))
        return left_beliefs, right_beliefs, None


if __name__ == '__main__':
    images = torch.ones(3, 22, 3, 128, 128)
    poses = torch.ones(3, 22, 150)
    gazes = torch.ones(3, 22, 6)
    bboxes = torch.ones(3, 22, 108)
    # Fused feature size: 512 (ResNet) + 150 (pose) + 6 (gaze) + 108 (bbox).
    # The original passed 32 here, which crashes the 776-dim belief heads.
    model = ResNet(512 + 150 + 6 + 108, 'cpu')
    print(model(images, poses, gazes, bboxes, None)[0].shape)