import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


def init_params(module, initializer='normal'):
    if isinstance(module, nn.Linear):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias.data)
        # log.info('initialized Linear')
    elif isinstance(module, nn.Embedding):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)
    elif isinstance(module, (nn.Conv2d, nn.Conv1d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')
        # log.info('initialized Conv')
    elif isinstance(module, (nn.RNNBase, nn.LSTMCell, nn.GRUCell)):
        for name, param in module.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                nn.init.normal_(param.data)
        # log.info('initialized LSTM')
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.weight.data.normal_(1.0, 0.02)
        # log.info('initialized BatchNorm')


def TensorboardWriter(save_path):
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(save_path, comment="Unmt")


DEFAULT_EPS = 1e-8
PADDED_Y_VALUE = -1


def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    # shuffle for randomised tie resolution
    random_indices = torch.randperm(y_pred.shape[-1])
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]

    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)
    mask = y_true_sorted == padded_value_indicator

    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")

    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values

    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])

    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
    observation_loss[mask] = 0.0

    return torch.mean(torch.sum(observation_loss, dim=1))
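

# Usage sketch (illustrative only, not part of the original module): `init_params` is meant to be
# passed to `nn.Module.apply`, and `listMLE` consumes one slate per row, with padded positions in
# `y_true` marked by PADDED_Y_VALUE. The model and tensor values below are made up for the example.
def _example_init_and_listmle():
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    model.apply(lambda m: init_params(m, initializer='truncated_normal'))

    y_pred = torch.randn(2, 5)                     # [batch_size, slate_length]
    y_true = torch.tensor([[3., 2., 0., 1., -1.],  # last item of the first slate is padding
                           [1., 0., 2., 1., 3.]])
    return listMLE(y_pred, y_true)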


def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.):
    """
    Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of
    Information Retrieval Measures". Please note that this method does not implement any kind of truncation.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param alpha: score difference weight used in the sigmoid function
    :return: loss value, a torch.Tensor
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)
    padded_pairs_mask.diagonal(dim1=-2, dim2=-1).zero_()

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum((torch.pow(2, y_true_sorted) - 1) / D, dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we approximate the ranking positions according to Eqs 19-20 and later approximate NDCG (Eq 21)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
    scores_diffs[~padded_pairs_mask] = 0.
    approx_pos = 1. + torch.sum(padded_pairs_mask.float() * (torch.sigmoid(-alpha * scores_diffs).clamp(min=eps)),
                                dim=-1)
    approx_D = torch.log2(1. + approx_pos)
    approx_NDCG = torch.sum((G / approx_D), dim=-1)

    return -torch.mean(approx_NDCG)


def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))
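

# Illustrative sketch (not from the original module): both listwise losses above take raw scores and
# labels of shape [batch_size, slate_length], with padded entries in `y_true` marked by PADDED_Y_VALUE,
# and both are differentiable w.r.t. the predictions. All values below are made up.
def _example_listwise_losses():
    y_pred = torch.tensor([[0.5, 1.2, -0.3, 0.0]], requires_grad=True)
    y_true = torch.tensor([[2., 3., 0., -1.]])  # last position is padding

    loss = approxNDCGLoss(y_pred, y_true, alpha=1.) + listNet(y_pred, y_true)
    loss.backward()  # gradients flow back to the non-padded entries of y_pred
    return loss.item()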


def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1]
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    dev = s.device

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
    temp = [t.type(torch.float32) for t in temp]
    temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
    scaling = torch.stack(temp).type(torch.float32).to(dev)  # type: ignore

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)

    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat


def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor:
    """
    Sampling from Gumbel distribution.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param samples_shape: shape of the output samples tensor
    :param device: device of the output samples tensor
    :param eps: epsilon for the logarithm function
    :return: Gumbel samples tensor of shape samples_shape
    """
    U = torch.rand(samples_shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE):
    mask = y_true == padding_indicator

    y_pred[mask] = float('-inf')
    y_true[mask] = 0.0

    _, indices = y_pred.sort(descending=True, dim=-1)
    return torch.gather(y_true, dim=1, index=indices)


def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Compute DCG at ranks given by ats or at the maximum rank if ats is None.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    actual_length = y_true.shape[1]

    if ats is None:
        ats = [actual_length]
    ats = [min(at, actual_length) for at in ats]

    true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to(
        device=true_sorted_by_preds.device)

    gains = gain_function(true_sorted_by_preds)
    discounted_gains = (gains * discounts)[:, :np.max(ats)]
    cum_dcg = torch.cumsum(discounted_gains, dim=1)

    ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1)
    dcg = cum_dcg[:, ats_tensor]
    return dcg
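

# Illustrative sketch (not from the original module): `deterministic_neural_sort` expects scores of
# shape [batch_size, slate_length, 1] plus a boolean padding mask of shape [batch_size, slate_length],
# while `dcg` works on plain [batch_size, slate_length] tensors. The toy slate below has no padding.
def _example_sort_and_dcg():
    y_pred = torch.tensor([[0.1, 2.0, 1.0]])
    y_true = torch.tensor([[1., 3., 2.]])
    mask = torch.zeros_like(y_true, dtype=torch.bool)

    # With a small temperature, P_hat approaches the hard permutation sorting y_pred descending.
    P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=0.1, mask=mask)
    dcg_at_2 = dcg(y_pred, y_true, ats=[2])  # DCG@2, shape [batch_size, len(ats)]
    return P_hat, dcg_at_2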


def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling procedure.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices
    """
    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
        mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0)

    for _ in range(max_iter):
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)

    return mat


def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1]
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = s.device

    batch_size = s.size()[0]
    n = s.size()[1]

    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    mask_repeated = mask.repeat_interleave(n_samples, dim=0)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat
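

# Illustrative sketch (not from the original module): `sinkhorn_scaling` alternately normalises rows
# and columns until the matrices are approximately doubly stochastic, and `stochastic_neural_sort`
# draws Gumbel-perturbed relaxed permutation matrices. Shapes and values below are arbitrary.
def _example_sinkhorn_and_stochastic_sort():
    mat = torch.rand(2, 4, 4) + 0.1  # strictly positive square matrices
    ds = sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=100)  # rows/columns now sum to ~1

    y_pred = torch.randn(2, 4)
    mask = torch.zeros_like(y_pred, dtype=torch.bool)  # no padding in this toy batch
    P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=8, tau=1.0, mask=mask)
    return ds, P_hat.shape  # P_hat: [n_samples, batch_size, slate_length, slate_length]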


def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable Relaxation
    of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    # Choose the deterministic/stochastic variant
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply it to the true labels, i.e. approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG
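

# Illustrative training-step sketch (not from the original module): `neuralNDCG` returns a negated
# (approximate) NDCG, so minimising it maximises NDCG. The linear scorer, feature sizes and labels
# below are stand-ins for a real learning-to-rank setup.
def _example_neuralndcg_step():
    scorer = nn.Linear(6, 1)
    optimizer = torch.optim.Adam(scorer.parameters(), lr=1e-3)

    features = torch.randn(2, 5, 6)               # [batch_size, slate_length, num_features]
    y_true = torch.tensor([[3., 1., 0., 2., 0.],  # graded relevance labels
                           [2., 2., 1., 0., 3.]])
    y_pred = scorer(features).squeeze(-1)         # [batch_size, slate_length]

    loss = neuralNDCG(y_pred, y_true, temperature=1., k=3)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()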


def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])

    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)

    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
    else:
        gains = y_true
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).squeeze()

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG
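

# Illustrative sanity check (not from the original module): on the same padded batch, both NeuralNDCG
# variants should return finite, non-positive values (each is a negated NDCG approximation). The
# inputs below are arbitrary.
def _example_compare_neuralndcg_variants():
    y_pred = torch.tensor([[1.5, 0.2, -0.7, 0.9],
                           [0.3, 1.1, 0.8, 0.0]])
    y_true = torch.tensor([[2., 0., 1., -1.],  # last item of the first slate is padding
                           [1., 3., 0., 2.]])

    loss_standard = neuralNDCG(y_pred, y_true, temperature=0.5, k=3)
    loss_transposed = neuralNDCG_transposed(y_pred, y_true, temperature=0.5, k=3)
    return loss_standard.item(), loss_transposed.item()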