# VDGR/utils/model_utils.py
#
# 457 lines
# 20 KiB
# Python

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with approximately truncated-normal samples.

    Four standard-normal candidates are drawn for every element and the first
    candidate lying strictly inside (-2, 2) is kept. The kept values are then
    scaled by ``std`` and shifted by ``mean``, so the truncation is at
    ``mean +/- 2 * std``. If none of the four candidates lies inside the
    interval, the first candidate is kept unchanged. This is rare, but it
    means the bound is not strictly guaranteed.

    :param tensor: tensor to overwrite in place
    :param mean: value added after scaling
    :param std: factor applied to the standard-normal draws
    :return: None (the tensor is modified in place)
    """
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    inside = (candidates > -2) & (candidates < 2)
    # max() over a boolean axis yields the index of a True entry, or 0 when all are False
    first_inside = inside.max(-1, keepdim=True)[1]
    chosen = candidates.gather(-1, first_inside).squeeze(-1)
    tensor.data.copy_(chosen)
    tensor.data.mul_(std).add_(mean)
def init_params(module, initializer='normal'):
    """Initialise the parameters of a single module in place.

    - ``nn.Linear`` / ``nn.Embedding`` weights follow ``initializer``:
      'kaiming_normal', 'normal' (std 0.02) or 'truncated_normal' (std 0.02).
      Any other value leaves these weights untouched. Linear biases, when
      present, are always zeroed.
    - ``nn.Conv1d`` / ``nn.Conv2d`` weights use Kaiming-normal with
      ``mode='fan_out'``.
    - Recurrent modules (``nn.RNNBase``, ``nn.LSTMCell``, ``nn.GRUCell``) get
      orthogonal weights and standard-normal biases.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` weights are drawn from
      N(1.0, 0.02).

    Modules of any other type are left unchanged.

    :param module: the module whose own parameters are initialised
    :param initializer: weight scheme for Linear/Embedding layers
    """
    def _init_dense_weight(weight):
        # Shared by Linear and Embedding: pick the scheme named by `initializer`.
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(weight)
        elif initializer == 'normal':
            nn.init.normal_(weight, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(weight, std=0.02)

    if isinstance(module, nn.Linear):
        _init_dense_weight(module.weight.data)
        if module.bias is not None:
            nn.init.zeros_(module.bias.data)
    elif isinstance(module, nn.Embedding):
        _init_dense_weight(module.weight.data)
    elif isinstance(module, (nn.Conv2d, nn.Conv1d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')
    elif isinstance(module, (nn.RNNBase, nn.LSTMCell, nn.GRUCell)):
        for param_name, param in module.named_parameters():
            if 'weight' in param_name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in param_name:
                nn.init.normal_(param.data)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.weight.data.normal_(1.0, 0.02)
def TensorboardWriter(save_path):
    """Create a TensorBoard ``SummaryWriter`` that logs into ``save_path``.

    The import is done at call time, so the tensorboard dependency is only
    needed when a writer is actually requested.

    NOTE(review): per the PyTorch SummaryWriter docs, ``comment`` only affects
    the auto-generated log directory. With an explicit ``save_path`` it appears
    to have no effect, so confirm whether it is still wanted.
    """
    from torch.utils import tensorboard

    writer = tensorboard.SummaryWriter(save_path, comment="Unmt")
    return writer
# Small constant added to log arguments and denominators for numerical stability.
DEFAULT_EPS = 1e-8
# Label value in ``y_true`` that marks a padded (non-existent) slate item.
PADDED_Y_VALUE = -1
def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value (mean over slates of the per-slate negative log-likelihood), a torch.Tensor
    """
    # One random permutation of slate positions, so that items with equal labels
    # end up in a random relative order after the (label) sort below.
    perm = torch.randperm(y_pred.shape[-1])
    shuffled_pred = y_pred[:, perm]
    shuffled_true = y_true[:, perm]

    sorted_true, order = shuffled_true.sort(descending=True, dim=-1)
    is_padded = sorted_true == padded_value_indicator

    # Predictions arranged in ground-truth order; padded items get -inf (exp -> 0).
    pred_in_true_order = torch.gather(shuffled_pred, dim=1, index=order)
    pred_in_true_order[is_padded] = float("-inf")

    # Subtract the per-slate maximum before exponentiating (log-sum-exp stabilisation).
    row_max = pred_in_true_order.max(dim=1, keepdim=True).values
    centered = pred_in_true_order - row_max

    # Reverse cumulative sum: entry i holds sum_{j >= i} exp(centered_j).
    tail_sums = centered.exp().flip(dims=[1]).cumsum(dim=1).flip(dims=[1])
    per_position = torch.log(tail_sums + eps) - centered
    per_position[is_padded] = 0.0

    return per_position.sum(dim=1).mean()
def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.):
    """
    Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of
    Information Retrieval Measures". Please note that this method does not implement any kind of truncation.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param alpha: score difference weight used in the sigmoid function
    :return: loss value (negative mean approximate NDCG), a torch.Tensor
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    # Padded items get -inf in both tensors so they sort to the end.
    is_padded = y_true == padded_value_indicator
    y_pred[is_padded] = float("-inf")
    y_true[is_padded] = float("-inf")

    y_pred_sorted, pred_order = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # Labels arranged in predicted order. A pair involving a padded item has a
    # non-finite label difference and is excluded, as is every (i, i) pair.
    true_by_pred = torch.gather(y_true, dim=1, index=pred_order)
    label_diffs = true_by_pred[:, :, None] - true_by_pred[:, None, :]
    valid_pairs = torch.isfinite(label_diffs)
    valid_pairs.diagonal(dim1=-2, dim2=-1).zero_()

    # -inf (padded) and any other negative labels are clamped to 0 for gains and ideal DCGs.
    true_by_pred.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Gains, discounts and ideal DCGs (maxDCGs) per slate.
    ranks = torch.arange(1, y_pred.shape[1] + 1).to(device)
    discounts = torch.log2(1. + ranks.float())[None, :]
    ideal_dcg = torch.sum((torch.pow(2, y_true_sorted) - 1) / discounts, dim=-1).clamp(min=eps)
    gains = (torch.pow(2, true_by_pred) - 1) / ideal_dcg[:, None]

    # Smooth ranks (Eqs 19-20): pos_i = 1 + sum_{j != i} sigmoid(-alpha * (s_i - s_j)),
    # followed by approximate NDCG (Eq 21).
    score_diffs = y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]
    score_diffs[~valid_pairs] = 0.
    soft_indicator = torch.sigmoid(-alpha * score_diffs).clamp(min=eps)
    approx_pos = 1. + torch.sum(valid_pairs.float() * soft_indicator, dim=-1)
    approx_ndcg = torch.sum(gains / torch.log2(1. + approx_pos), dim=-1)
    return -torch.mean(approx_ndcg)
def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    The loss is the cross-entropy between the softmax of the labels and the softmax of the predictions.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    is_padded = y_true == padded_value_indicator
    # masked_fill returns new tensors, so the caller's inputs are left untouched;
    # -inf gives padded items zero probability under softmax.
    pred_logits = y_pred.masked_fill(is_padded, float('-inf'))
    true_logits = y_true.masked_fill(is_padded, float('-inf'))

    pred_probs = F.softmax(pred_logits, dim=1)
    true_probs = F.softmax(true_logits, dim=1)

    # eps keeps the log finite where a predicted probability is exactly zero.
    cross_entropy = -(true_probs * torch.log(pred_probs + eps)).sum(dim=1)
    return cross_entropy.mean()
def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking, vectorised per-slate scaling).
    :param s: values to sort, shape [batch_size, slate_length, 1]
    :param tau: temperature for the final softmax function
    :param mask: boolean mask indicating padded elements, shape [batch_size, slate_length]
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length];
        each row sums to 1 (softmax over the last axis)
    """
    dev = s.device
    n = s.size()[1]
    # Follow the input's floating dtype instead of hard-coding float32 (float64 inputs used to fail in matmul).
    one = torch.ones((n, 1), dtype=s.dtype, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    # Pairwise absolute differences; pairs involving a padded item contribute nothing.
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)
    # B[b, i, j] = sum_k A_s[b, i, k], i.e. row sums repeated along the last axis.
    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    # Per-slate scaling (n_valid + 1 - 2 * r) for 1-based positions r <= n_valid and 0 afterwards.
    # Vectorised: the former per-row Python loop crashed for slate_length == 1
    # (mask.squeeze(-1) collapsed the slate axis before .sum(dim=1)).
    n_valid = (n - mask.sum(dim=1)).unsqueeze(-1)  # [batch_size, 1]
    ranks = torch.arange(1, n + 1, device=dev).unsqueeze(0)  # [1, slate_length]
    scaling = (n_valid + 1 - 2 * ranks).masked_fill(ranks > n_valid, 0).to(s.dtype)

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))
    P_max = (C - B).permute(0, 2, 1)
    # Padded rows/columns get -inf (zero probability), except padded-padded entries,
    # which get a finite value so padded rows still form a valid softmax.
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], float('-inf'))
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    return torch.softmax(P_max / tau, dim=-1)
def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor:
    """
    Draw standard Gumbel samples via the inverse-CDF transform -log(-log(U)), with U ~ torch.rand.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param samples_shape: shape of the output samples tensor
    :param device: device of the output samples tensor
    :param eps: epsilon added inside both logarithms to avoid log(0)
    :return: Gumbel samples tensor of shape samples_shape
    """
    uniform = torch.rand(samples_shape, device=device)
    neg_log_u = -torch.log(uniform + eps)
    return -torch.log(neg_log_u + eps)
def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE):
    """
    Mask padded items, then return the ground-truth labels reordered by descending prediction.

    NOTE: both arguments are modified in place. Padded predictions become -inf, so they
    rank last, and padded labels become 0.0, so they contribute a zero label. Pass copies
    if the originals must be preserved.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padding_indicator: label value that marks a padded item
    :return: y_true values gathered along dim 1 in order of descending y_pred
    """
    padded = y_true == padding_indicator
    y_pred[padded] = float('-inf')
    y_true[padded] = 0.0
    ranking = y_pred.sort(descending=True, dim=-1).indices
    return y_true.gather(1, ranking)
def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Compute DCG at ranks given by ats or at the maximum rank if ats is None.
    Ranks larger than the slate length are clamped to the slate length.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    # Work on copies: the masking helper writes into its arguments.
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    slate_length = y_true.shape[1]
    cutoffs = [slate_length] if ats is None else [min(at, slate_length) for at in ats]

    true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    # Discount for 0-based position p is 1 / log2(p + 2).
    positions = torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float)
    discounts = (torch.tensor(1) / torch.log2(positions + 2.0)).to(device=true_sorted_by_preds.device)

    gains = gain_function(true_sorted_by_preds)
    discounted_gains = (gains * discounts)[:, :max(cutoffs)]
    cumulative = torch.cumsum(discounted_gains, dim=1)

    # 1-based cut-off ranks -> 0-based column indices into the running sum.
    cutoff_idx = torch.tensor(cutoffs, dtype=torch.long) - 1
    return cumulative[:, cutoff_idx]
def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling procedure: alternately normalise columns and rows until every row and column
    sum is within ``tol`` of 1, or ``max_iter`` iterations have run.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices; entries in padded rows/columns are 0
    """
    if mask is not None:
        pad_rows = mask[:, :, None]
        pad_cols = mask[:, None, :]
        # Zero every entry touching a padded item, but set padded-padded entries to 1.0.
        # This gives padded rows/columns a non-zero sum, so they can also normalise to 1.
        mat = mat.masked_fill(pad_cols | pad_rows, 0.0)
        mat = mat.masked_fill(pad_cols & pad_rows, 1.0)

    for _ in range(max_iter):
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)
        # Stop once both row sums and column sums are within tolerance of 1.
        if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol:
            if torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
                break

    if mask is not None:
        mat = mat.masked_fill(pad_cols | pad_rows, 0.0)
    return mat
def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1]
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: boolean mask indicating padded elements, shape [batch_size, slate_length]
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = s.device
    batch_size = s.size()[0]
    n = s.size()[1]
    # Shift scores so their minimum is non-negative (required before the optional log).
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    # Broadcasting yields [n_samples, batch_size, n, 1]. Flattening the first two axes is
    # sample-major: flattened row r belongs to batch element r % batch_size.
    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    # The mask must use the same sample-major layout, so tile it with repeat().
    # repeat_interleave would pair row r with batch element r // n_samples, which mismatches
    # padding masks whenever batch_size > 1 and n_samples > 1.
    mask_repeated = mask.repeat(n_samples, 1)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    return P_hat.view(n_samples, batch_size, n, n)
def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :return: loss value (negative mean NDCG over slates with non-zero ideal DCG), a torch.Tensor;
        a zero tensor if every slate has zero ideal DCG
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant. Both give [n_draws, batch_size, n, n],
    # with n_draws == 1 for the deterministic variant.
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening [n_draws, batch_size] is sample-major (row r <-> batch element r % batch_size),
    # so the mask is tiled with repeat() rather than repeat_interleave().
    n_draws = P_hat.shape[0]
    P_hat = sinkhorn_scaling(P_hat.view(n_draws * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat(n_draws, 1), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(n_draws, y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply to true labels, ie approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    # The ideal DCG uses the same gain function as the numerator.
    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    # Slates whose ideal DCG is zero carry no ranking signal; exclude them from the mean.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    if idcg_mask.all():
        return torch.tensor(0., device=dev)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG
def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value (negative mean NDCG over slates with non-zero ideal DCG), a torch.Tensor;
        a zero tensor if every slate has zero ideal DCG
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening [n_draws, batch_size] is sample-major (row r <-> batch element r % batch_size),
    # so the mask is tiled with repeat() rather than repeat_interleave().
    n_draws = P_hat.shape[0]
    P_hat_masked = sinkhorn_scaling(P_hat.view(n_draws * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat(n_draws, 1), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(n_draws, y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])

    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]
    # Expected discounts: item j receives sum_i P[i, j] * d_i (padded columns of P are all zero).
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)

    # The ideal DCG must use the same gain function as the numerator. Previously the
    # linear-gain branch divided by an exponential-gain (2^x - 1) ideal DCG.
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        idcg = dcg(y_true, y_true, ats=[k]).squeeze(-1)
    else:
        gains = y_true
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).squeeze(-1)
    discounted_gains = gains.unsqueeze(0) * discounts

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    # Slates whose ideal DCG is zero carry no ranking signal; exclude them from the mean.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    if idcg_mask.all():
        return torch.tensor(0., device=dev)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG