176 lines
No EOL
7.2 KiB
Python
176 lines
No EOL
7.2 KiB
Python
'''
|
|
Differentiable approximation to the mutual information (MI) metric.
|
|
Implementation in PyTorch
|
|
'''
|
|
|
|
# Imports #
|
|
# ----------------------------------------------------------------------
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import cm
|
|
import os
|
|
|
|
# Note: This code snippet was taken from the discussion found at:
|
|
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
|
|
# By Tony-Y
|
|
class SoftHistogram1D(nn.Module):
    '''
    Differentiable 1D histogram (gradients flow through PyTorch's autograd).

    Every hard rectangular bin is replaced by the difference of two sigmoids
    centred on the bin edges, so each sample contributes a smooth weight to
    every bin instead of a hard 0/1 count.

    input:
        x     - N x D tensor, where N is the batch size and D is the length of each data series
    parameters:
        bins  - number of histogram bins
        min   - scalar lower edge of the histogram range
        max   - scalar upper edge of the histogram range
        sigma - scalar sharpness of the sigmoid bin edges; larger values give sharper
                edges and therefore a closer approximation of a hard histogram
    output:
        N x bins tensor, where each row is the (unnormalized) histogram of one series
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super().__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma

        # Width of a single bin
        self.delta = float(max - min) / float(bins)

        # Bin centres: half a bin above `min`, then every `delta`
        bin_index = torch.arange(bins).float()
        centers = float(min) + self.delta * (bin_index + 0.5)
        # Non-trainable Parameter, so the centres move with the module (e.g. .cuda())
        self.centers = nn.Parameter(centers, requires_grad=False)

    def forward(self, x):
        half_width = self.delta / 2
        # N x bins x D: signed distance of every sample to every bin centre
        offsets = x.unsqueeze(1) - self.centers.unsqueeze(1)
        # Soft rectangular window of width `delta` around each centre
        upper_edge = torch.sigmoid(self.sigma * (offsets + half_width))
        lower_edge = torch.sigmoid(self.sigma * (offsets - half_width))
        # Accumulate over the data axis -> N x bins (no normalization applied)
        return (upper_edge - lower_edge).sum(dim=-1)
|
|
|
|
|
|
# Note: This is an extension to the 2D case of the previous code snippet
|
|
class SoftHistogram2D(nn.Module):
    '''
    Differentiable 2D (joint) histogram (gradients flow through PyTorch's autograd).

    Both inputs are soft-binned with the same sigmoid-window scheme. The joint
    histogram is the batched product of the two per-sample bin-weight matrices,
    so the pair (x[n, d], y[n, d]) contributes weight_x[i] * weight_y[j] to cell (i, j).

    input:
        x, y  - N x D tensors of identical size, where N is the batch size and D is
                the length of each data series (i.e. a vectorized image or 3D volume)
    parameters:
        bins  - number of histogram bins per axis
        min   - scalar lower edge of the histogram range (shared by x and y)
        max   - scalar upper edge of the histogram range (shared by x and y)
        sigma - scalar sharpness of the sigmoid bin edges; larger values give sharper
                edges and therefore a closer approximation of a hard histogram
    output:
        N x bins x bins tensor; entry [n, i, j] is the soft count of positions whose
        x value falls in bin i and whose y value falls in bin j
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super().__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma

        # Width of a single bin
        self.delta = float(max - min) / float(bins)

        # Bin centres: half a bin above `min`, then every `delta`
        bin_index = torch.arange(bins).float()
        centers = float(min) + self.delta * (bin_index + 0.5)
        # Non-trainable Parameter, so the centres move with the module (e.g. .cuda())
        self.centers = nn.Parameter(centers, requires_grad=False)

    def forward(self, x, y):
        assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"

        half_width = self.delta / 2
        column_centers = self.centers.unsqueeze(1)

        def soft_membership(values):
            # N x bins x D soft bin weights (sigma and delta shared by both axes)
            offsets = values.unsqueeze(1) - column_centers
            return (torch.sigmoid(self.sigma * (offsets + half_width))
                    - torch.sigmoid(self.sigma * (offsets - half_width)))

        weights_x = soft_membership(x)
        weights_y = soft_membership(y)

        # Batched matrix product sums the joint contributions over the data axis
        return torch.matmul(weights_x, weights_y.permute((0, 2, 1)))
|
|
|
|
|
|
class MI_pytorch(nn.Module):
    r'''
    PyTorch implementation of the mutual information (MI) between two images.

    The result is an approximation, because the joint histogram is built from
    differentiable (sigmoid) approximations of rectangular bin windows. In exchange,
    the estimate is differentiable with respect to both inputs.

        I(X, Y) = H(X) + H(Y) - H(X, Y)
                = \sum_x \sum_y p(x, y) * log(p(x, y) / (p(x) * p(y)))

    where H(X) = -\sum_x p(x) * log(p(x)) is the entropy.

    :param bins: number of histogram bins per axis
    :param min: lower edge of the histogram range (shared by both images)
    :param max: upper edge of the histogram range (shared by both images)
    :param sigma: sharpness of the soft bin edges (larger = closer to a hard histogram)
    :param reduction: 'sum'        - sum of the per-sample MI values (default)
                      'batchmean'  - mean of the per-sample MI values
                      'individual' - one MI value per batch element, shape (N,)
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
        super(MI_pytorch, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.reduction = reduction

        # 2D joint histogram
        self.hist2d = SoftHistogram2D(bins, min, max, sigma)

        # Epsilon - avoids log(0) in the batched (batch_size > 1) code path
        self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)

    def forward(self, im1, im2):
        '''
        Forward implementation of a differentiable MI estimator for batched images.

        :param im1: N x ... tensor, where N is the batch size; the trailing dimensions
                    can take any form, e.g. 2D images or 3D volumes
        :param im2: tensor with exactly the same size as im1
        :return: depending on self.reduction -
                 'individual': tensor of shape (N,), one approximate MI value per sample
                 'sum':        scalar tensor, sum of the per-sample values
                 'batchmean':  scalar tensor, mean of the per-sample values
        :raises AssertionError: if im1 and im2 differ in size
        :raises ValueError: if self.reduction is not one of the values above
        '''
        # Check for valid inputs
        assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."

        batch_size = im1.size(0)

        # Flatten every sample to a vector; reshape (unlike view) also accepts
        # non-contiguous inputs such as permuted/sliced tensors.
        im1_flat = im1.reshape(batch_size, -1)
        im2_flat = im2.reshape(batch_size, -1)

        # Joint histogram: batch_size x bins x bins ([n, i, j] -> x bin i, y bin j)
        hgram = self.hist2d(im1_flat, im2_flat)

        # Normalize EACH sample's histogram separately so that every joint
        # distribution sums to 1. (Dividing by the sum over the whole batch would
        # scale every sample by ~1/batch_size and corrupt the MI for batch_size > 1.)
        Pxy = hgram / hgram.sum(dim=(1, 2), keepdim=True)

        # Marginal distributions, each batch_size x 1 x bins
        Py = torch.sum(Pxy, dim=1).unsqueeze(1)  # summed over x bins -> p(y)
        Px = torch.sum(Pxy, dim=2).unsqueeze(1)  # summed over y bins -> p(x)

        # Outer product p(x) * p(y): batch_size x bins x bins, indexed like Pxy
        Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)

        # Reshape to batch_size x all_the_rest
        Pxy = Pxy.reshape(batch_size, -1)
        Px_Py = Px_Py.reshape(batch_size, -1)

        # KL divergence between p(x, y) and p(x) * p(y), per sample
        if batch_size == 1:
            # Single sample: use only the non-zero entries, no eps approximation
            nzs = Pxy > 0
            mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs]))
            # Keep the per-sample shape (batch_size,) consistent with the batched path
            mut_info = mut_info.reshape(1)
        else:
            # Arbitrary batch size > 1: eps keeps log() finite for empty cells
            mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)

        # Reduction
        if self.reduction == 'sum':
            mut_info = torch.sum(mut_info)
        elif self.reduction == 'batchmean':
            mut_info = torch.sum(mut_info) / float(batch_size)
        elif self.reduction == 'individual':
            pass
        else:
            raise ValueError(
                "(MI_pytorch) Unknown reduction '{}'; expected 'sum', 'batchmean' "
                "or 'individual'.".format(self.reduction))

        return mut_info