DisMouse/model/MI.py

176 lines
7.2 KiB
Python
Raw Permalink Normal View History

2024-10-08 14:18:47 +02:00
'''
Differentiable approximation to the mutual information (MI) metric.
Implementation in PyTorch
'''
# Imports #
# ----------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os
# Note: This code snippet was taken from the discussion found at:
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
# By Tony-Y
class SoftHistogram1D(nn.Module):
'''
Differentiable 1D histogram calculation (supported via pytorch's autograd)
inupt:
x - N x D array, where N is the batch size and D is the length of each data series
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor fir the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins array, where each row is a histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram1D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap for allow for cuda support
def forward(self, x):
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
# Sum along the non-batch dimensions
x = x.sum(dim=-1)
# x = x / x.sum(dim=-1).unsqueeze(1) # normalization
return x
# Note: This is an extension to the 2D case of the previous code snippet
class SoftHistogram2D(nn.Module):
'''
Differentiable 1D histogram calculation (supported via pytorch's autograd)
inupt:
x, y - N x D array, where N is the batch size and D is the length of each data series
(i.e. vectorized image or vectorized 3D volume)
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor fir the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins array, where each row is a histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram2D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap for allow for cuda support
def forward(self, x, y):
assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))
# Batched matrix multiplication - this way we sum jointly
z = torch.matmul(x, y.permute((0, 2, 1)))
return z
class MI_pytorch(nn.Module):
'''
This class is a pytorch implementation of the mutual information (MI) calculation between two images.
This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.
I(X, Y) = H(X) + H(Y) - H(X, Y) = \sum(\sum(p(X, Y) * log(p(Y, Y)/(p(X) * p(Y)))))
where H(X) = -\sum(p(x) * log(p(x))) is the entropy
'''
def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
super(MI_pytorch, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.reduction = reduction
# 2D joint histogram
self.hist2d = SoftHistogram2D(bins, min, max, sigma)
# Epsilon - to avoid log(0)
self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)
def forward(self, im1, im2):
'''
Forward implementation of a differentiable MI estimator for batched images
:param im1: N x ... tensor, where N is the batch size
... dimensions can take any form, i.e. 2D images or 3D volumes.
:param im2: N x ... tensor, where N is the batch size
:return: N x 1 vector - the approximate MI values between the batched im1 and im2
'''
# Check for valid inputs
assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."
batch_size = im1.size()[0]
# Flatten tensors
im1_flat = im1.view(im1.size()[0], -1)
im2_flat = im2.view(im2.size()[0], -1)
# Calculate joint histogram
hgram = self.hist2d(im1_flat, im2_flat)
# Convert to a joint distribution
# Pxy = torch.distributions.Categorical(probs=hgram).probs
Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1)))
# Calculate the marginal distributions
Py = torch.sum(Pxy, dim=1).unsqueeze(1)
Px = torch.sum(Pxy, dim=2).unsqueeze(1)
# Use the KL divergence distance to calculate the MI
Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)
# Reshape to batch_size X all_the_rest
Pxy = Pxy.reshape(batch_size, -1)
Px_Py = Px_Py.reshape(batch_size, -1)
# Calculate mutual information - this is an approximation due to the histogram calculation and eps,
# but it can handle batches
if batch_size == 1:
# No need for eps approximation in the case of a single batch
nzs = Pxy > 0 # Calculate based on the non-zero values only
mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs])) # MI calculation
else:
# For arbitrary batch size > 1
mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)
# Reduction
if self.reduction == 'sum':
mut_info = torch.sum(mut_info)
elif self.reduction == 'batchmean':
mut_info = torch.sum(mut_info)
mut_info = mut_info / float(batch_size)
elif self.reduction=='individual':
pass
return mut_info