visrecall/RecallNet/src/losses_keras2.py

import keras.backend as K
import numpy as np
import tensorflow as tf
from sal_imp_utilities import *  # provides shape_r_out and shape_c_out
from tensorflow.keras.losses import KLDivergence

# KL-Divergence Loss
def kl_divergence(y_true, y_pred):
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_pred /= max_y_pred
    sum_y_true = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_true, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    sum_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_true /= (sum_y_true + K.epsilon())
    y_pred /= (sum_y_pred + K.epsilon())
    # This constant was defined by Cornia et al. and is a bit arbitrary
    return K.sum(K.sum(y_true * K.log((y_true / (y_pred + K.epsilon())) + K.epsilon()), axis=1), axis=1)
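
# NumPy reference for the computation above (an editor's sketch, not part of the
# original training pipeline): the same per-sample KL, written for a single
# (r, c) map pair. Handy for eager sanity checks against kl_divergence.
def kl_divergence_numpy(y_true, y_pred, eps=1e-7):
    """y_true, y_pred: float arrays of shape (r, c). Returns a scalar KL."""
    y_pred = y_pred / y_pred.max()          # max-normalize the prediction
    y_true = y_true / (y_true.sum() + eps)  # sum-normalize both to distributions
    y_pred = y_pred / (y_pred.sum() + eps)
    return np.sum(y_true * np.log(y_true / (y_pred + eps) + eps))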

def kl_time(y_true, y_pred):
    if len(y_true.shape) == 5:
        ax = 2
    else:
        ax = 1
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    y_pred /= max_y_pred
    sum_y_true = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_true, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    sum_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_pred, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    y_true /= (sum_y_true + K.epsilon())
    y_pred /= (sum_y_pred + K.epsilon())
    kl_out = K.sum(K.sum(y_true * K.log((y_true / (y_pred + K.epsilon())) + K.epsilon()), axis=ax), axis=ax)
    if len(y_true.shape) == 5:
        kl_out = K.mean(kl_out, axis=1)
    return kl_out
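
# Editor's usage sketch (hedged, not part of the original file): kl_time accepts
# a single map batch (bs, r, c, 1) or a time-distributed batch (bs, t, r, c, 1);
# in the 5D case the per-timestep KL is averaged over the time axis. Assumes a
# TF2-style eager backend and the output shapes from sal_imp_utilities.
def _demo_kl_time(batch=2, timesteps=3):
    """Run kl_time on random time-distributed maps; one score per sample."""
    gt = tf.random.uniform((batch, timesteps, shape_r_out, shape_c_out, 1))
    pred = tf.random.uniform((batch, timesteps, shape_r_out, shape_c_out, 1))
    return kl_time(gt, pred)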

# Correlation Coefficient Loss
def correlation_coefficient(y_true, y_pred):
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_pred /= max_y_pred
    sum_y_true = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_true, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    sum_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_true /= (sum_y_true + K.epsilon())
    y_pred /= (sum_y_pred + K.epsilon())
    N = shape_r_out * shape_c_out
    sum_prod = K.sum(K.sum(y_true * y_pred, axis=1), axis=1)
    sum_x = K.sum(K.sum(y_true, axis=1), axis=1)
    sum_y = K.sum(K.sum(y_pred, axis=1), axis=1)
    sum_x_square = K.sum(K.sum(K.square(y_true), axis=1), axis=1)
    sum_y_square = K.sum(K.sum(K.square(y_pred), axis=1), axis=1)
    num = sum_prod - ((sum_x * sum_y) / N)
    den = K.sqrt((sum_x_square - K.square(sum_x) / N) * (sum_y_square - K.square(sum_y) / N))
    return num / den
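
# NumPy reference for the computation above (an editor's sketch): Pearson
# correlation between two maps after the same normalization, mirroring the
# num/den expression used in the Keras version.
def correlation_coefficient_numpy(y_true, y_pred, eps=1e-7):
    """y_true, y_pred: float arrays of shape (r, c). Returns a scalar CC."""
    y_pred = y_pred / y_pred.max()
    y_true = y_true / (y_true.sum() + eps)
    y_pred = y_pred / (y_pred.sum() + eps)
    x, y = y_true.ravel(), y_pred.ravel()
    n = x.size
    num = (x * y).sum() - x.sum() * y.sum() / n
    den = np.sqrt((np.square(x).sum() - x.sum() ** 2 / n) *
                  (np.square(y).sum() - y.sum() ** 2 / n))
    return num / den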

def cc_time(y_true, y_pred):
    if len(y_true.shape) == 5:
        ax = 2
    else:
        ax = 1
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    y_pred /= max_y_pred
    sum_y_true = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_true, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    sum_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_pred, axis=ax), axis=ax), axis=ax),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    y_true /= (sum_y_true + K.epsilon())
    y_pred /= (sum_y_pred + K.epsilon())
    N = shape_r_out * shape_c_out
    sum_prod = K.sum(K.sum(y_true * y_pred, axis=ax), axis=ax)
    sum_x = K.sum(K.sum(y_true, axis=ax), axis=ax)
    sum_y = K.sum(K.sum(y_pred, axis=ax), axis=ax)
    sum_x_square = K.sum(K.sum(K.square(y_true), axis=ax), axis=ax)
    sum_y_square = K.sum(K.sum(K.square(y_pred), axis=ax), axis=ax)
    num = sum_prod - ((sum_x * sum_y) / N)
    den = K.sqrt((sum_x_square - K.square(sum_x) / N) * (sum_y_square - K.square(sum_y) / N))
    if len(y_true.shape) == 5:
        cc_out = K.mean(num / den, axis=1)
    else:
        cc_out = num / den
    return cc_out

# Normalized Scanpath Saliency Loss
def nss_time(y_true, y_pred):
    if len(y_true.shape) == 5:
        ax = 2
    else:
        ax = 1
    maxi = K.max(K.max(y_pred, axis=ax), axis=ax)
    first_rep = K.repeat_elements(K.expand_dims(maxi, axis=ax), shape_r_out, axis=ax)
    max_y_pred = K.repeat_elements(K.expand_dims(first_rep, axis=ax+1), shape_c_out, axis=ax+1)
    y_pred /= max_y_pred
    if len(y_true.shape) == 5:
        y_pred_flatten = K.reshape(y_pred, (K.shape(y_pred)[0], K.shape(y_pred)[1],
                                            K.shape(y_pred)[2]*K.shape(y_pred)[3]*K.shape(y_pred)[4]))  # K.batch_flatten(y_pred)
    else:
        y_pred_flatten = K.batch_flatten(y_pred)
    y_mean = K.mean(y_pred_flatten, axis=-1)
    y_mean = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.expand_dims(y_mean)),
                                                               shape_r_out, axis=ax)), shape_c_out, axis=ax+1)
    y_std = K.std(y_pred_flatten, axis=-1)
    y_std = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.expand_dims(y_std)),
                                                              shape_r_out, axis=ax)), shape_c_out, axis=ax+1)
    y_pred = (y_pred - y_mean) / (y_std + K.epsilon())
    num = K.sum(K.sum(y_true * y_pred, axis=ax), axis=ax)
    den = K.sum(K.sum(y_true, axis=ax), axis=ax) + K.epsilon()
    if len(y_true.shape) == 5:
        nss_out = K.mean(num / den, axis=1)
    else:
        nss_out = num / den
    return nss_out

def nss(y_true, y_pred):
    ax = 1
    # NOTE: this early-out is only well-defined when running eagerly with a
    # single-sample batch; in graph mode the comparison yields a symbolic tensor.
    if K.sum(K.sum(y_true, axis=ax), axis=ax) == 0:
        return 0
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=ax), axis=ax), axis=ax+1),
                                                                   shape_r_out, axis=ax), axis=ax+1), shape_c_out, axis=ax+1)
    y_pred /= max_y_pred
    y_pred_flatten = K.batch_flatten(y_pred)
    y_mean = K.mean(y_pred_flatten, axis=-1)
    y_mean = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.expand_dims(y_mean)),
                                                               shape_r_out, axis=ax)), shape_c_out, axis=ax+1)
    y_std = K.std(y_pred_flatten, axis=-1)
    y_std = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.expand_dims(y_std)),
                                                              shape_r_out, axis=ax)), shape_c_out, axis=ax+1)
    y_pred = (y_pred - y_mean) / (y_std + K.epsilon())
    num = K.sum(K.sum(y_true * y_pred, axis=ax), axis=ax)
    den = K.sum(K.sum(y_true, axis=ax), axis=ax) + K.epsilon()
    nss_out = num / den
    return nss_out
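
# NumPy reference for NSS (an editor's sketch): standardize the prediction to
# zero mean and unit variance, then average it over the fixated locations in
# the binary fixation map y_true.
def nss_numpy(y_true, y_pred, eps=1e-7):
    """y_true: binary fixation map (r, c); y_pred: saliency map (r, c)."""
    y_pred = y_pred / y_pred.max()
    y_pred = (y_pred - y_pred.mean()) / (y_pred.std() + eps)
    return (y_true * y_pred).sum() / (y_true.sum() + eps)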

def cc_match(y_true, y_pred):
    '''Calculates CC between the initial, mid and final timesteps of both y_true and y_pred,
    then returns the mean absolute error between the CCs from y_true and those from y_pred.
    Requires y_true and y_pred to be tensors of shape (bs, t, r, c, 1).'''
    mid = 1  # y_true.shape[1].value//2
    ccim_true = cc_time(y_true[:, 0, ...], y_true[:, mid, ...])
    ccmf_true = cc_time(y_true[:, mid, ...], y_true[:, -1, ...])
    ccim_pred = cc_time(y_pred[:, 0, ...], y_pred[:, mid, ...])
    ccmf_pred = cc_time(y_pred[:, mid, ...], y_pred[:, -1, ...])
    return (K.abs(ccim_true - ccim_pred) + K.abs(ccmf_true - ccmf_pred)) / 2

def kl_cc_nss_combined(lw=[10, -2, -1]):
    # DEPRECATED
    '''Loss function that combines cc, nss and kl. Because nss receives a different ground truth than kl and cc (maps),
    the function requires y_true to contain both maps. It has to be a tensor with dimensions [bs, 2, r, c, 1]. y_pred also
    has to be a tensor of the same dim, so the model should add a 5th dimension between bs and r and repeat the predicted
    map twice along that dim.
    '''
    def loss(y_true, y_pred):
        map_true = y_true[:, 0, ...]
        fix_true = y_true[:, 1, ...]
        pred = y_pred[:, 0, ...]
        k = kl_divergence(map_true, pred)
        c = correlation_coefficient(map_true, pred)
        n = nss(fix_true, pred)
        return lw[0]*k + lw[1]*c + lw[2]*n
    return loss
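
# Editor's sketch (hedged) of the stacked ground truth the docstring above
# describes: the density map and fixation map are stacked along a new axis 1,
# giving the (bs, 2, r, c, 1) tensor that kl_cc_nss_combined expects.
def stack_map_and_fixations(density_map, fixation_map):
    """density_map, fixation_map: tensors of shape (bs, r, c, 1)."""
    return tf.stack([density_map, fixation_map], axis=1)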

def loss_wrapper(loss, input_shape):
    # NOTE: this rebinding is local to the wrapper; the losses above resolve
    # shape_r_out/shape_c_out from the module-level names imported from
    # sal_imp_utilities, so this mainly serves the print below.
    shape_r_out, shape_c_out = input_shape
    print("shape r out, shape c out", shape_r_out, shape_c_out)
    def _wrapper(y_true, y_pred):
        return loss(y_true, y_pred)
    return _wrapper
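
# Usage sketch (editor's example, hedged): compiling a model with one of the
# losses above. The optimizer and learning rate are illustrative, not taken
# from the original training code.
def compile_with_wrapped_loss(model, loss_fn, lr=1e-4):
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss=loss_wrapper(loss_fn, (shape_r_out, shape_c_out)))
    return model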

def kl_new(y_true, y_pred):
    '''
    KL loss for the single-duration model; the old kl_divergence() may produce
    NaNs during training.
    '''
    max_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_pred /= max_y_pred
    sum_y_true = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_true, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    sum_y_pred = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.sum(K.sum(y_pred, axis=1), axis=1), axis=1),
                                                                   shape_r_out, axis=1), axis=2), shape_c_out, axis=2)
    y_true /= (sum_y_true + K.epsilon())
    y_pred /= (sum_y_pred + K.epsilon())
    kl = tf.keras.losses.KLDivergence()
    return kl(y_true, y_pred)

def kl_cc_combined(y_true, y_pred):
    # For Singleduration
    '''Loss function that combines kl and cc for the single-duration model.
    Unlike kl_cc_nss_combined, y_true and y_pred are plain map tensors of
    shape (bs, r, c, 1); no stacked fixation map is required.
    '''
    #k = kl_time(y_true, y_pred)
    k = kl_new(y_true, y_pred)
    print('k=', k)
    #c = cc_time(y_true, y_pred)
    c = correlation_coefficient(y_true, y_pred)
    print('c=', c)
    return 10*k - 3*c
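
# Eager sanity check (editor's sketch): assumes a TF2 backend running eagerly
# and the output shapes defined in sal_imp_utilities.
if __name__ == "__main__":
    gt = tf.random.uniform((2, shape_r_out, shape_c_out, 1))
    pred = tf.random.uniform((2, shape_r_out, shape_c_out, 1))
    print("kl_cc_combined:", kl_cc_combined(gt, pred))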