IRENE/tom/norm.py
2024-02-01 15:40:47 +01:00

46 lines
No EOL
1.8 KiB
Python

import torch
import torch.nn as nn
class Norm(nn.Module):
    """Node-feature normalization: BatchNorm, per-graph GraphNorm, or identity.

    ``norm_type`` selects the behavior:

    * ``'bn'`` -- wraps ``nn.BatchNorm1d(hidden_dim)``; ``graph`` is ignored.
    * ``'gn'`` -- normalizes the node rows of each graph in a batch separately
      (graph sizes come from ``graph.batch_num_nodes('obj')``), subtracting a
      mean scaled by the learnable ``mean_scale`` and applying the learnable
      affine ``weight`` / ``bias``.
    * anything else (including ``None``) -- identity; the input is returned
      unchanged.
    """

    def __init__(self, norm_type, hidden_dim=64, print_info=None):
        """Build the layer.

        Args:
            norm_type: ``'bn'``, ``'gn'``, or any other value for identity.
                (``'ln'`` appears in an old commented-out assert but is not
                implemented; it behaves as identity.)
            hidden_dim: Size of the feature dimension the parameters cover.
            print_info: Stored as ``self.print_info``; not used in this class.
        """
        super().__init__()
        self.norm = None
        self.print_info = print_info
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'gn':
            # A string marker: forward() dispatches on it to the GraphNorm path.
            self.norm = norm_type
        # Registered for every norm_type although only the 'gn' path uses them;
        # kept unconditional so existing state_dicts keep loading.
        self.weight = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))
        self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    @staticmethod
    def _segment_mean(values, index, counts, batch_size):
        """Average the rows of ``values`` within each graph of the batch.

        Args:
            values: Row-stacked per-node values.
            index: Graph id of every row, expanded to ``values.shape``.
            counts: Number of rows belonging to each graph (long tensor).
            batch_size: Number of graphs.

        Returns:
            Tensor of shape ``(batch_size, *values.shape[1:])``.
        """
        # new_zeros follows values' dtype/device; scatter_add_ requires the
        # destination dtype to match the source (float64 / float16 inputs
        # previously failed against a float32 accumulator).
        total = values.new_zeros((batch_size,) + tuple(values.shape[1:]))
        total.scatter_add_(0, index, values)
        # Reshape counts to (batch_size, 1, ..., 1) so it broadcasts over all
        # feature dims for any rank, instead of the deprecated ``(x.T / c).T``.
        # clamp avoids 0/0 for empty graphs; their rows never reach the output.
        divisor = counts.clamp(min=1).view((-1,) + (1,) * (values.dim() - 1))
        return total / divisor

    def forward(self, graph, tensor, print_=False):
        """Normalize ``tensor`` according to the configured ``norm_type``.

        Args:
            graph: Only read on the 'gn' path. It must provide
                ``batch_num_nodes('obj')`` returning a tensor of per-graph
                node counts (presumably a batched DGL graph -- confirm).
            tensor: Node features whose rows are grouped graph-by-graph in
                batch order; for 'gn' the trailing dimension must broadcast
                against ``hidden_dim``.
            print_: Unused; kept for interface compatibility.

        Returns:
            Normalized tensor with the same shape as ``tensor``.
        """
        if self.norm is None:
            return tensor
        if not isinstance(self.norm, str):
            return self.norm(tensor)

        # --- GraphNorm path ---
        counts = graph.batch_num_nodes('obj')
        batch_size = len(counts)
        counts = counts.long().to(tensor.device)
        # Graph id of every node row, expanded to the full shape of `tensor`
        # so scatter_add_ sums whole feature rows per graph.
        graph_ids = torch.arange(batch_size, device=tensor.device).repeat_interleave(counts)
        index = graph_ids.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)

        mean = self._segment_mean(tensor, index, counts, batch_size)
        sub = tensor - mean.repeat_interleave(counts, dim=0) * self.mean_scale
        var = self._segment_mean(sub.pow(2), index, counts, batch_size)
        std = (var + 1e-6).sqrt().repeat_interleave(counts, dim=0)
        return self.weight * sub / std + self.bias