257 lines
11 KiB
Python
257 lines
11 KiB
Python
import keras.backend as K
|
|
import numpy as np
|
|
from sal_imp_utilities import *
|
|
from tensorflow.keras.losses import KLDivergence
|
|
|
|
# KL-Divergence Loss
|
|
def kl_divergence(y_true, y_pred):
    """KL divergence between a ground-truth saliency map and a predicted map.

    Both maps are turned into distributions that sum to 1 over axes 1 and 2,
    and ``sum(P * log(P / Q))`` is returned over those same axes. Axes 1 and 2
    are presumably the spatial (rows, cols) axes of
    ``(batch, rows, cols[, channels])`` inputs — TODO confirm with callers.

    Args:
        y_true: ground-truth saliency map tensor (P).
        y_pred: predicted saliency map tensor (Q), same shape as ``y_true``.

    Returns:
        Tensor with axes 1 and 2 reduced away.
    """
    spatial = [1, 2]

    # Scale the prediction by its per-map maximum. The epsilon prevents an
    # all-zero prediction from producing 0/0 = NaN.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Normalise both maps to sum to 1. keepdims=True lets the per-map sums
    # broadcast back over the spatial axes for any map size, so no global
    # shape constants are needed.
    y_true = y_true / (K.sum(y_true, axis=spatial, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Epsilons guard the division and the log against zeros.
    return K.sum(y_true * K.log(y_true / (y_pred + K.epsilon()) + K.epsilon()),
                 axis=spatial)
|
|
|
|
def kl_time(y_true, y_pred):
    """KL divergence for single saliency maps or time sequences of maps.

    For 5-D inputs (presumably ``(batch, time, rows, cols, channels)`` —
    TODO confirm), the divergence is computed over axes 2 and 3 and then
    averaged over axis 1. For any other rank it is computed over axes 1 and 2.

    Args:
        y_true: ground-truth saliency map(s) (P).
        y_pred: predicted saliency map(s) (Q), same shape as ``y_true``.

    Returns:
        The per-map divergence. For 5-D inputs it is additionally averaged
        over axis 1.
    """
    is_sequence = len(y_true.shape) == 5
    ax = 2 if is_sequence else 1
    spatial = [ax, ax + 1]

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Normalise both maps into distributions over the spatial axes.
    y_true = y_true / (K.sum(y_true, axis=spatial, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    kl_out = K.sum(y_true * K.log(y_true / (y_pred + K.epsilon()) + K.epsilon()),
                   axis=spatial)

    if is_sequence:
        # Average the per-timestep divergences.
        kl_out = K.mean(kl_out, axis=1)

    return kl_out
|
|
|
|
# Correlation Coefficient Loss
|
|
def correlation_coefficient(y_true, y_pred):
    """Pearson correlation coefficient between two saliency maps.

    The coefficient is computed per map over axes 1 and 2. These are presumably
    the spatial axes of ``(batch, rows, cols[, channels])`` inputs — TODO confirm.
    Both maps are first scaled to sum to 1, as the other losses here do. The
    coefficient itself is invariant to that positive rescaling.

    Args:
        y_true: ground-truth saliency map tensor.
        y_pred: predicted saliency map tensor, same shape as ``y_true``.

    Returns:
        Tensor with axes 1 and 2 reduced away; values lie in [-1, 1].
    """
    spatial = [1, 2]

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())
    y_true = y_true / (K.sum(y_true, axis=spatial, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Centre each map on its spatial mean. This is algebraically identical to
    # sum(xy) - sum(x)sum(y)/N (and likewise for the variances). It avoids
    # floating-point cancellation and takes the pixel count N from the tensor
    # instead of module-level shape globals.
    true_c = y_true - K.mean(y_true, axis=spatial, keepdims=True)
    pred_c = y_pred - K.mean(y_pred, axis=spatial, keepdims=True)

    num = K.sum(true_c * pred_c, axis=spatial)
    var_prod = (K.sum(K.square(true_c), axis=spatial)
                * K.sum(K.square(pred_c), axis=spatial))

    # A constant map has zero variance, which would give 0/0. Clamping the
    # product from below keeps the value and the sqrt gradient finite. The
    # floor is (1e-7)^2 with the default epsilon, far below ordinary values,
    # so it does not bias normal maps.
    den = K.sqrt(K.maximum(var_prod, K.epsilon() ** 2))

    return num / den
|
|
|
|
def cc_time(y_true, y_pred):
    """Pearson correlation coefficient for single maps or time sequences of maps.

    For 5-D inputs (presumably ``(batch, time, rows, cols, channels)`` —
    TODO confirm), the coefficient is computed over axes 2 and 3 and then
    averaged over axis 1. For any other rank it is computed over axes 1 and 2.

    Args:
        y_true: ground-truth saliency map(s).
        y_pred: predicted saliency map(s), same shape as ``y_true``.

    Returns:
        The per-map coefficient. For 5-D inputs it is additionally averaged
        over axis 1.
    """
    is_sequence = len(y_true.shape) == 5
    ax = 2 if is_sequence else 1
    spatial = [ax, ax + 1]

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())
    y_true = y_true / (K.sum(y_true, axis=spatial, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Centred form: algebraically equal to the sum(xy) - sum(x)sum(y)/N
    # formulation. It is more stable and needs no global pixel count N.
    true_c = y_true - K.mean(y_true, axis=spatial, keepdims=True)
    pred_c = y_pred - K.mean(y_pred, axis=spatial, keepdims=True)

    num = K.sum(true_c * pred_c, axis=spatial)
    var_prod = (K.sum(K.square(true_c), axis=spatial)
                * K.sum(K.square(pred_c), axis=spatial))

    # Floor the variance product so constant maps stay finite (value and
    # gradient) without biasing ordinary maps.
    den = K.sqrt(K.maximum(var_prod, K.epsilon() ** 2))

    cc_out = num / den
    if is_sequence:
        # Average the per-timestep coefficients.
        cc_out = K.mean(cc_out, axis=1)

    return cc_out
|
|
|
|
# Normalized Scanpath Saliency Loss
|
|
def nss_time(y_true, y_pred):
    """Normalized Scanpath Saliency for single maps or time sequences of maps.

    The prediction is standardised to zero mean and unit standard deviation
    per map. The result is the ``y_true``-weighted mean of the standardised
    prediction: ``sum(y_true * z) / sum(y_true)``. ``y_true`` is presumably a
    fixation map — confirm with callers.

    For 5-D inputs (presumably ``(batch, time, rows, cols, channels)``), the
    score is computed per index of axis 1 and then averaged over axis 1.
    Otherwise it is computed per batch element.

    Args:
        y_true: ground-truth fixation map(s).
        y_pred: predicted saliency map(s), same shape as ``y_true``.

    Returns:
        The NSS score. For 5-D inputs it is averaged over axis 1.
    """
    is_sequence = len(y_true.shape) == 5
    ax = 2 if is_sequence else 1
    spatial = [ax, ax + 1]
    # Every axis from `ax` onward: the mean/std are taken over rows, cols and
    # any trailing channel axis together (one statistic per map).
    per_map = list(range(ax, len(y_pred.shape)))

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Standardise. keepdims=True broadcasts the statistics back for any rank,
    # including inputs without a channel axis.
    y_mean = K.mean(y_pred, axis=per_map, keepdims=True)
    y_std = K.std(y_pred, axis=per_map, keepdims=True)
    y_pred = (y_pred - y_mean) / (y_std + K.epsilon())

    num = K.sum(y_true * y_pred, axis=spatial)
    den = K.sum(y_true, axis=spatial) + K.epsilon()
    nss_out = num / den

    if is_sequence:
        nss_out = K.mean(nss_out, axis=1)

    return nss_out
|
|
|
|
|
|
def nss(y_true, y_pred):
    """Normalized Scanpath Saliency of a predicted map against a fixation map.

    The prediction is standardised per map (over every non-batch axis). The
    score is ``sum(y_true * z) / (sum(y_true) + eps)`` over axes 1 and 2.
    An empty ``y_true`` therefore yields 0 without any special-casing.

    Args:
        y_true: ground-truth fixation map, presumably shaped like ``y_pred``
            (e.g. ``(batch, rows, cols, 1)``) — TODO confirm.
        y_pred: predicted saliency map.

    Returns:
        Tensor with axes 1 and 2 reduced away.
    """
    # NOTE: the previous `if K.sum(...) == 0: return 0` guard was removed.
    # A Python `if` on a batch tensor fails in eager mode and under
    # tf.function, and is always False under TF1 (identity `==`). The epsilon
    # in the denominator below already handles empty fixation maps per sample.
    spatial = [1, 2]
    per_map = list(range(1, len(y_pred.shape)))

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    # keepdims broadcasting works for any channel count.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Standardise with one mean/std per map.
    y_mean = K.mean(y_pred, axis=per_map, keepdims=True)
    y_std = K.std(y_pred, axis=per_map, keepdims=True)
    y_pred = (y_pred - y_mean) / (y_std + K.epsilon())

    num = K.sum(y_true * y_pred, axis=spatial)
    den = K.sum(y_true, axis=spatial) + K.epsilon()

    return num / den
|
|
|
|
def cc_match(y_true, y_pred):
    """Penalise mismatched temporal consistency between ground truth and prediction.

    For each of ``y_true`` and ``y_pred`` (both expected to be shaped
    ``(bs, t, r, c, 1)``), compute the correlation coefficient via
    ``cc_time`` for two pairs of timesteps:

    * first timestep vs. middle timestep, and
    * middle timestep vs. last timestep.

    Return the mean of the absolute differences between the ground-truth and
    predicted coefficients. The middle index is fixed at 1, which is the
    middle when t == 3.
    """
    mid = 1  # fixed "middle" index (previously sketched as y_true.shape[1].value // 2)

    def _neighbour_ccs(seq):
        # CC(first, middle) and CC(middle, last) for one sequence tensor.
        first, middle, last = seq[:, 0, ...], seq[:, mid, ...], seq[:, -1, ...]
        return cc_time(first, middle), cc_time(middle, last)

    true_first_mid, true_mid_last = _neighbour_ccs(y_true)
    pred_first_mid, pred_mid_last = _neighbour_ccs(y_pred)

    gap_first_mid = K.abs(true_first_mid - pred_first_mid)
    gap_mid_last = K.abs(true_mid_last - pred_mid_last)
    return (gap_first_mid + gap_mid_last) / 2
|
|
|
|
def kl_cc_nss_combined(lw=(10, -2, -1)):
    # DEPRECATED
    """Build a loss combining KL divergence, correlation coefficient and NSS.

    Because NSS receives a different ground truth (fixations) than KL and CC
    (saliency maps), ``y_true`` must contain both. It must be a tensor with
    dimensions ``[bs, 2, r, c, 1]``: index 0 on axis 1 is the saliency map and
    index 1 is the fixation map. ``y_pred`` must have the same dimensions, so
    the model should add a dimension between ``bs`` and ``r`` and repeat the
    predicted map twice along it. Only index 0 of ``y_pred`` is used.

    Args:
        lw: weights ``(kl, cc, nss)``. The default is a tuple rather than a
            list so the shared default object cannot be mutated across calls.
            Any indexable with at least three entries works.

    Returns:
        A ``loss(y_true, y_pred)`` callable returning
        ``lw[0] * KL + lw[1] * CC + lw[2] * NSS``.
    """
    def loss(y_true, y_pred):
        map_true = y_true[:, 0, ...]
        fix_true = y_true[:, 1, ...]
        pred = y_pred[:, 0, ...]

        k = kl_divergence(map_true, pred)
        c = correlation_coefficient(map_true, pred)
        n = nss(fix_true, pred)

        # CC and NSS are similarities, so their default weights are negative.
        return lw[0] * k + lw[1] * c + lw[2] * n

    return loss
|
|
|
|
|
|
def loss_wrapper(loss, input_shape):
    """Return a two-argument closure that forwards straight to ``loss``.

    ``input_shape`` is unpacked into a (rows, cols) pair and printed for
    reference.

    NOTE(review): the unpacked values are locals of this call. They are
    neither passed to ``loss`` nor rebound at module level, so they do not
    influence the computed loss. Confirm this is intended.
    """
    rows_out, cols_out = input_shape
    print("shape r out, shape c out", rows_out, cols_out)

    def _wrapper(y_true, y_pred):
        """Forward ``(y_true, y_pred)`` to the wrapped loss unchanged."""
        return loss(y_true, y_pred)

    return _wrapper
|
|
|
|
|
|
def kl_new(y_true, y_pred):
    """KL-divergence loss for the single-duration model.

    The old ``kl_divergence()`` may cause NaN in training. This variant
    normalises both maps over axes 1 and 2 and then delegates to Keras'
    ``KLDivergence`` loss.

    NOTE(review): Keras' ``KLDivergence`` sums over the *last* axis only and
    then applies its default reduction (mean). For ``(batch, rows, cols, 1)``
    maps this is a per-pixel mean rather than a per-map sum, i.e. a rescaled
    KL. Confirm that the scaling is intended.

    Args:
        y_true: ground-truth saliency map tensor.
        y_pred: predicted saliency map tensor, same shape as ``y_true``.

    Returns:
        Whatever ``KLDivergence()(y_true, y_pred)`` returns (a reduced loss).
    """
    spatial = [1, 2]

    # Max-normalise the prediction; the epsilon avoids NaN for all-zero maps.
    y_pred = y_pred / (K.max(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    y_true = y_true / (K.sum(y_true, axis=spatial, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=spatial, keepdims=True) + K.epsilon())

    # Use the class imported at the top of this module. `tf` is never imported
    # here, so `tf.keras.losses.KLDivergence` only worked if the star-import
    # happened to leak it.
    kl = KLDivergence()
    return kl(y_true, y_pred)
|
|
|
|
def kl_cc_combined(y_true, y_pred):
    """Combined KL-divergence / correlation-coefficient loss (single-duration model).

    Returns ``10 * kl_new(y_true, y_pred) - 3 * correlation_coefficient(y_true, y_pred)``.
    Both terms receive the same tensors directly. Unlike
    ``kl_cc_nss_combined``, there is no NSS term and no stacked
    ``[bs, 2, ...]`` ground truth. KL is minimised while CC is maximised,
    hence the negative CC weight. The two terms are combined by ordinary
    tensor broadcasting.

    Args:
        y_true: ground-truth saliency map tensor.
        y_pred: predicted saliency map tensor, same shape as ``y_true``.
    """
    # Debug prints and commented-out kl_time/cc_time alternatives were removed.
    # The prints fired on every trace/call of the loss.
    k = kl_new(y_true, y_pred)
    c = correlation_coefficient(y_true, y_pred)
    return 10 * k - 3 * c
|
|
|