# SalChartQA/Code/model_swin.py
import torch


class SalFormer(torch.nn.Module):
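    """
    SalFormer: fuses Swin image tokens and BERT question tokens with
    cross-attention and self-attention, then decodes a question-conditioned
    saliency map with a small convolutional decoder.
    """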
    def __init__(self, vision_encoder, bert):
        """
        Wrap a pretrained vision encoder (Swin) and a pretrained BERT text
        encoder, and build the fusion layers and the convolutional decoder
        on top of their 768-dim token features.
        """
        super().__init__()
        self.vit = vision_encoder
        self.feature_dim = 768
        self.bert = bert
        # Projection heads applied to the encoders' token features
        self.vision_head = torch.nn.Sequential(
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU(),
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU()
        )
        self.text_head = torch.nn.Sequential(
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU(),
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU()
        )
        # Cross-attention that pools the question tokens into learned queries,
        # and a second cross-attention that fuses image tokens with the joint sequence
        self.cross_attention = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True)
        self.cross_attention1 = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True)
        self.ln1 = torch.nn.LayerNorm(self.feature_dim)
        self.ln2 = torch.nn.LayerNorm(self.feature_dim)
        self.self_attention = torch.nn.MultiheadAttention(self.feature_dim, 16, batch_first=True)
        # 10 learned query tokens for the text branch, plus learned positional
        # embeddings for the 49 image tokens (7x7 Swin grid) and the 10 text tokens
        self.text_feature_query = torch.nn.Parameter(torch.randn(10, self.feature_dim).unsqueeze(0) / 2)
        self.img_positional_embedding = torch.nn.Parameter(torch.zeros(49, self.feature_dim))
        self.text_positional_embedding = torch.nn.Parameter(torch.zeros(10, self.feature_dim))
        self.dense1 = torch.nn.Linear(self.feature_dim, self.feature_dim)
        self.relu1 = torch.nn.ReLU()
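        # Convolutional decoder: with the 3x3 unpadded convs and the fixed
        # bilinear upsamples below, the spatial size runs
        # 7 -> 5 -> 3 -> 16 -> 14 -> 12 -> 32 -> 30 -> 28 -> 130 -> 128,
        # ending in a single-channel 128x128 map squashed to [0, 1] by a sigmoid.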
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(self.feature_dim, 512, 3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(512, 512, 3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((16, 16), mode='bilinear'),
            torch.nn.Conv2d(512, 256, 3),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(256, 256, 3),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((32, 32), mode='bilinear'),
            torch.nn.Conv2d(256, 128, 3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(128, 128, 3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((130, 130), mode='bilinear'),
            torch.nn.Conv2d(128, 1, 3),
            torch.nn.BatchNorm2d(1),
            torch.nn.Sigmoid(),
        )
        # Encoders start in eval mode; note that self.train(True) below sets
        # all child modules (including both encoders) back to training mode.
        self.vit.eval()
        self.bert.eval()
        self.train(True)

    # def eval(self):
    #     super().eval()
    #     self.vit.eval()
    #     self.bert.eval()

    # def train(self, mode=True):
    #     super().train(mode)
    #     self.vit.train(mode)
    #     self.bert.train(mode)
    def forward(self, img, q_inputs):
        # Image tokens from the Swin encoder: (B, 49, 768)
        img_features = self.vit(img, return_dict=True)["last_hidden_state"]
        # Question tokens from BERT, kept out of the computation graph
        with torch.no_grad():
            text_features = self.bert(**q_inputs)["last_hidden_state"]
        # text_features = torch.unsqueeze(bert_output["last_hidden_state"][:,0,:], 1)
        # Pool the variable-length question into the 10 learned query tokens
        text_features = self.cross_attention(self.text_feature_query.repeat([text_features.shape[0], 1, 1]), text_features, text_features, need_weights=False)[0]
        # Concatenate projected image and text tokens (plus positional embeddings)
        # and fuse them with self-attention, residual connection and layer norm
        fused_features = torch.concat((self.vision_head(img_features) + self.img_positional_embedding, self.text_head(text_features) + self.text_positional_embedding), 1)
        att_fused_features = self.self_attention(fused_features, fused_features, fused_features, need_weights=False)[0]
        fused_features = fused_features + att_fused_features
        fused_features = self.ln1(fused_features)
        # Cross-attend the image tokens to the fused sequence (residual + layer norm)
        features = self.cross_attention1(img_features, fused_features, fused_features, need_weights=False)[0]
        features = img_features + features
        features = self.ln2(features)
        features = self.dense1(features)
        latent_features = self.relu1(features)
        # Reshape the 49 tokens back into a 7x7 feature map and decode a saliency map
        latent_features = latent_features.permute(0, 2, 1)
        out = torch.reshape(latent_features, (features.shape[0], self.feature_dim, 7, 7))
        out = self.decoder(out)
        return out
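

# --- Minimal usage sketch (not part of the original file) ---
# Assumes Hugging Face `transformers` checkpoints that match the hard-coded
# sizes above: a Swin encoder whose last hidden state is 49 x 768 tokens
# (e.g. swin-tiny at 224x224 input) and a 768-dim BERT. The checkpoint names
# and the example question below are illustrative assumptions.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoImageProcessor, AutoTokenizer, BertModel, SwinModel

    vision_encoder = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    bert = BertModel.from_pretrained("bert-base-uncased")
    model = SalFormer(vision_encoder, bert)

    image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Dummy chart image and question; replace with real SalChartQA inputs
    img = image_processor(Image.new("RGB", (224, 224)), return_tensors="pt")["pixel_values"]
    q_inputs = tokenizer("What is the highest bar?", return_tensors="pt")

    saliency = model(img, q_inputs)  # (1, 1, 128, 128) map in [0, 1]
    print(saliency.shape)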