import torch
import torch.nn as nn


class Norm(nn.Module):
    """Switchable feature normalization for batched graphs.

    Modes, chosen by ``norm_type`` at construction:

    * ``'bn'`` -- ``nn.BatchNorm1d(hidden_dim)`` applied to the feature tensor;
      the graph argument is ignored.
    * ``'gn'`` -- per-graph normalization (GraphNorm style): rows belonging to
      the same graph are shifted by ``mean_scale * mean`` and divided by the
      standard deviation of that shifted value, then scaled by ``weight`` and
      offset by ``bias`` (all three are learnable, shape ``(hidden_dim,)``).
    * anything else (including ``None``) -- identity.

    Args:
        norm_type: ``'bn'``, ``'gn'`` or another value meaning "no normalization".
        hidden_dim: size of the last feature dimension.
        print_info: stored on the instance but not used by this class.
    """

    def __init__(self, norm_type, hidden_dim=64, print_info=None):
        super().__init__()
        self.norm = None
        self.print_info = print_info
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'gn':
            # The string itself is the marker that forward() must run the
            # per-graph path.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, graph, tensor, print_=False):
        """Normalize ``tensor`` according to the configured mode.

        Args:
            graph: only used in ``'gn'`` mode; must provide
                ``batch_num_nodes('obj')`` returning per-graph node counts
                (presumably a batched DGL heterograph -- confirm with caller).
            tensor: node features with the graph's rows stacked along dim 0;
                the last dimension must broadcast against ``hidden_dim``.
            print_: unused; kept for call-site compatibility.

        Returns:
            A tensor with the same shape as ``tensor``.
        """
        if self.norm is None:
            return tensor
        if not isinstance(self.norm, str):
            return self.norm(tensor)
        return self._graph_norm(graph, tensor)

    def _graph_norm(self, graph, tensor):
        """Apply learnable per-graph mean/std normalization to ``tensor``.

        Rows of ``tensor`` must be grouped contiguously by graph, in the same
        order as the counts returned by ``graph.batch_num_nodes('obj')``.
        """
        device = tensor.device
        # Accept either a tensor or a plain sequence of counts.
        counts = torch.as_tensor(graph.batch_num_nodes('obj'), device=device).long()
        batch_size = counts.numel()

        # Graph id for every row of `tensor`.
        batch_index = torch.arange(batch_size, device=device).repeat_interleave(counts)

        # Divisor shaped to broadcast over all feature dims; clamping keeps the
        # statistics of node-less graphs finite (those rows are never gathered).
        denom = counts.clamp(min=1).view((-1,) + (1,) * (tensor.dim() - 1))
        stat_shape = (batch_size,) + tuple(tensor.shape[1:])

        # Accumulators share the input's dtype/device so index_add_ accepts
        # float64 / half inputs too.
        mean = tensor.new_zeros(stat_shape).index_add_(0, batch_index, tensor) / denom
        sub = tensor - mean[batch_index] * self.mean_scale

        var = tensor.new_zeros(stat_shape).index_add_(0, batch_index, sub.pow(2)) / denom
        std = (var + 1e-6).sqrt()
        return self.weight * sub / std[batch_index] + self.bias