46 lines
1.8 KiB
Python
46 lines
1.8 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class Norm(nn.Module):
    """Optional normalization layer for node features of a batched graph.

    ``norm_type`` selects the behavior:

    * ``'bn'`` -- ``nn.BatchNorm1d(hidden_dim)`` applied to the whole tensor.
    * ``'gn'`` -- GraphNorm: within each graph, subtract the graph's mean
      scaled by the learnable per-feature ``mean_scale``, divide by the
      resulting per-graph standard deviation (with a ``1e-6`` variance
      floor), then apply the learnable ``weight`` and ``bias``.
    * any other value (including ``None``) -- identity.
    """

    def __init__(self, norm_type, hidden_dim=64, print_info=None):
        """Build the layer.

        Args:
            norm_type: ``'bn'``, ``'gn'``, or any other value for identity.
            hidden_dim: feature size; sizes BatchNorm and the GraphNorm
                parameters.
            print_info: stored as ``self.print_info``; not read by this class.
        """
        super().__init__()
        self.norm = None
        self.print_info = print_info
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'gn':
            # The string itself marks that forward() must take the
            # hand-written GraphNorm path instead of calling a module.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, graph, tensor, print_=False):
        """Normalize ``tensor`` according to the configured ``norm_type``.

        Args:
            graph: batched graph whose ``batch_num_nodes('obj')`` gives the
                number of 'obj' nodes in each member graph (presumably a DGL
                heterograph -- confirm with callers). Only used for 'gn'.
            tensor: node features; for 'gn' its first dimension must list the
                'obj' nodes graph by graph, in the same order as those counts.
            print_: unused; kept for interface compatibility.

        Returns:
            A tensor with the same shape as ``tensor``.
        """
        if self.norm is None:
            return tensor
        if not isinstance(self.norm, str):
            return self.norm(tensor)

        # as_tensor accepts either a tensor or a plain list of node counts.
        counts = torch.as_tensor(
            graph.batch_num_nodes('obj'), device=tensor.device
        ).long()
        batch_size = len(counts)

        # (batch_size, 1, ..., 1): lets per-graph counts broadcast over the
        # feature dims without the `.T` trick (deprecated for non-2-D tensors).
        count_shape = (-1,) + (1,) * (tensor.dim() - 1)
        counts_b = counts.view(count_shape)

        # Row i of `tensor` belongs to graph graph_ids[i].
        graph_ids = torch.arange(batch_size, device=tensor.device).repeat_interleave(counts)
        index = graph_ids.view(count_shape).expand_as(tensor)
        pooled_shape = (batch_size,) + tuple(tensor.shape[1:])

        # new_zeros matches tensor's dtype and device: scatter_add_ requires
        # the accumulator dtype to equal the source dtype, so a float32
        # torch.zeros buffer fails for float64/float16 inputs.
        mean = tensor.new_zeros(pooled_shape).scatter_add_(0, index, tensor)
        mean = (mean / counts_b).repeat_interleave(counts, dim=0)

        sub = tensor - mean * self.mean_scale

        var = tensor.new_zeros(pooled_shape).scatter_add_(0, index, sub.pow(2))
        std = (var / counts_b + 1e-6).sqrt().repeat_interleave(counts, dim=0)
        return self.weight * sub / std + self.bias
|