visrecall/RecallNet/src/util.py

68 lines
2.4 KiB
Python

from singleduration_models import sam_resnet_new, UMSI, RecallNet_UMSI, RecallNet_xception, RecallNet_xception_aspp
from losses_keras2 import loss_wrapper, kl_time, cc_time, nss_time, cc_match, kl_cc_combined
# Registry of available models, keyed by the name callers pass to
# get_model_by_name(). Each value is a pair: the model object imported from
# singleduration_models, and a string indicating the model's mode of use
# (every entry currently uses "simple").
MODELS = {
    "sam-resnet": (sam_resnet_new, "simple"),
    "UMSI": (UMSI, "simple"),
    "RecallNet_xception": (RecallNet_xception, "simple"),
    "RecallNet_xception_aspp": (RecallNet_xception_aspp, "simple"),
    "RecallNet_UMSI": (RecallNet_UMSI, "simple"),
}
# Registry of custom losses, keyed by the name callers pass to
# get_loss_by_name(). Each value is a pair: the loss object imported from
# losses_keras2, and the kind of target it is tagged with ("heatmap" or
# "fixmap"). create_losses() uses that tag to order heatmap losses first.
LOSSES = {
    "kl": (kl_time, "heatmap"),
    "cc": (cc_time, "heatmap"),
    "nss": (nss_time, "fixmap"),
    "ccmatch": (cc_match, "heatmap"),
    "kl+cc": (kl_cc_combined, "heatmap"),
}
def get_model_by_name(name):
    """Look up a registered model by its name.

    Args:
        name: Key into the module-level ``MODELS`` registry.

    Returns:
        The tuple stored in ``MODELS`` under ``name``: the model object and
        a string indicating its mode of use.

    Raises:
        RuntimeError: If ``name`` is not a key of ``MODELS``. The message
            lists every registered model name, comma-separated.
    """
    if name in MODELS:
        return MODELS[name]

    # Unknown name: report it together with all valid choices.
    known_names = ",".join(MODELS)
    raise RuntimeError(
        "Model %s is not recognized. Please choose one of: %s" % (name, known_names)
    )
def get_loss_by_name(name, out_size):
    """Resolve a loss name to a loss and the kind of target it expects.

    Args:
        name: Key into the module-level ``LOSSES`` registry, or any other
            string that should be handed through unchanged.
        out_size: Forwarded as the second argument to ``loss_wrapper`` when
            a custom loss is found.

    Returns:
        A ``(loss, out_type)`` pair. For a registered name, ``loss`` is the
        registered loss passed through ``loss_wrapper(loss, out_size)`` and
        ``out_type`` is the tag stored alongside it. For an unregistered
        name, a warning is printed and ``(name, 'heatmap')`` is returned so
        that keras can interpret the string as one of its own losses.
    """
    if name in LOSSES:
        base_loss, out_type = LOSSES[name]
        return loss_wrapper(base_loss, out_size), out_type

    # No custom loss registered: fall back to the raw string and treat it
    # as a heatmap-type loss.
    print("WARNING: found no custom loss with name %s, defaulting to a string." % name)
    return name, "heatmap"
def create_losses(loss_dict, out_size):
    """Build ordered loss functions and weights from a name-to-weight mapping.

    Each name in ``loss_dict`` is resolved through ``get_loss_by_name``. By
    convention, losses that take a heatmap (as opposed to a fixmap) come
    first in the returned list; this function enforces that ordering while
    keeping each loss paired with its weight.

    Args:
        loss_dict: Mapping from loss name to its weight.
        out_size: Forwarded to ``get_loss_by_name`` for every loss.

    Returns:
        A 4-tuple ``(losses, weights, description, n_heatmaps)``:
        ``losses`` is the list of resolved losses, heatmap-type first;
        ``weights`` is the list of weights in the same order;
        ``description`` concatenates ``name + str(weight)`` for every entry
        in ``loss_dict`` iteration order (not the heatmap-first order);
        ``n_heatmaps`` is the number of heatmap-type losses.
    """
    heatmap_pairs = []
    fixmap_pairs = []
    name_parts = []

    for loss_name, weight in loss_dict.items():
        loss_fn, out_type = get_loss_by_name(loss_name, out_size)
        # Anything not tagged 'heatmap' is grouped with the fixmap losses.
        bucket = heatmap_pairs if out_type == "heatmap" else fixmap_pairs
        bucket.append((loss_fn, weight))
        name_parts.append(loss_name + str(weight))

    # Heatmap losses first, then fixmap losses; split pairs back into
    # parallel lists.
    ordered_pairs = heatmap_pairs + fixmap_pairs
    losses = [fn for fn, _ in ordered_pairs]
    weights = [w for _, w in ordered_pairs]

    return losses, weights, "".join(name_parts), len(heatmap_pairs)