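# Source: visrecall/RecallNet/src/cb.py (81 lines, 2.5 KiB, Python)
# Keras training callbacks and helpers: encoder unfreezing, live loss
# plotting, checkpointing, and a step-decay learning-rate schedule.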

import keras
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os
import keras
from keras.callbacks import ModelCheckpoint
import numpy as np
import math
class Unfreeze(keras.callbacks.Callback):
    """Set every layer of the model trainable once a batch count is reached.

    Args:
        it_to_unfreeze: number of training batches, counted across all
            epochs, at which all layers are made trainable. This happens
            exactly once.
    """

    def __init__(self, it_to_unfreeze):
        super().__init__()
        self.it_to_unfreeze = it_to_unfreeze
        self.c = 0  # batches seen so far; never reset between epochs
        self.frozen = True

    def on_batch_end(self, batch, logs=None):
        self.c += 1
        # Fire once, on the first batch at which the threshold is reached.
        if self.frozen and self.c >= self.it_to_unfreeze:
            print('Iteration %d reached: UNFREEZING ENCODER' % self.c)
            self.frozen = False
            for layer in self.model.layers:
                layer.trainable = True
            # NOTE(review): Keras documents that changing `trainable` after
            # compile() only takes effect once the model is compiled again;
            # confirm the training loop accounts for this.
class InteractivePlot(keras.callbacks.Callback):
    """Live-plot the training loss inside an IPython/Jupyter output cell.

    A loss value is recorded at every epoch end and also every 10 batches.
    Each recording clears the cell output and redraws the whole loss curve.
    """

    def __init__(self):
        super().__init__()

    def on_train_begin(self, logs=None):
        # Reset the history at the start of each fit() call.
        self.losses = []
        self.logs = []
        self.batchnr = 0
        self.icount = 0

    def on_train_end(self, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        """Record the current 'loss' from ``logs`` and redraw the plot.

        Also invoked from ``on_batch_end``, so one point is plotted per call,
        not per real epoch. A missing ``logs`` no longer raises.
        """
        logs = logs or {}
        self.batchnr = 0
        self.losses.append(logs.get('loss'))
        self.icount += 1

        clear_output(wait=True)
        fig = plt.figure(figsize=(14, 10))
        for values, label in zip([self.losses], ['loss']):
            plt.plot(range(self.icount), values, label=label)
        plt.legend()
        plt.show()
        # Release the figure; otherwise a new figure accumulates per redraw.
        plt.close(fig)

    def on_batch_end(self, batch, logs=None):
        self.batchnr += 1
        # Every 10th batch, reuse the epoch-end path to add a point and redraw.
        if self.batchnr % 10 == 0:
            self.on_epoch_end(epoch=0, logs=logs)
def ckpt_callback(model_name, dataset, l_str, bs, extra_str='',
                  period=1, save_weights_only=True,
                  ckpt_folder_path = '../../predimportance_shared/models/ckpt/'):
    """Build a ModelCheckpoint that writes into ``<ckpt_folder_path>/<model_name>/``.

    The directory is created if needed. Files are named
    ``<model_name>_<dataset>_<l_str>_bs<bs><extra_str>_ep<NN>_valloss<X.XXXX>.hdf5``;
    the epoch and val_loss fields are filled in by ModelCheckpoint.

    Args:
        model_name, dataset, l_str: strings concatenated into the file name.
        bs: batch size, stringified into the file name.
        extra_str: optional suffix appended after the batch size.
        period: forwarded to ModelCheckpoint (previously ignored).
        save_weights_only: forwarded to ModelCheckpoint (previously ignored).
        ckpt_folder_path: root folder under which a per-model folder is made.

    Returns:
        A ModelCheckpoint callback monitoring 'val_loss' with verbose=1.
    """
    path = os.path.join(ckpt_folder_path, model_name)
    os.makedirs(path, exist_ok=True)
    filename = (model_name + '_' + dataset + '_' + l_str + '_bs' + str(bs)
                + extra_str + '_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')
    filepath = os.path.join(path, filename)
    # NOTE(review): newer Keras versions deprecate `period` in favour of
    # `save_freq`; kept here to match the Keras version this file targets.
    return ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
                           save_weights_only=save_weights_only, period=period)
def step_decay(init_lr = 0.0001, drop = 0.1, epochs_drop = 3.0):
    """Return a step-decay schedule ``epoch -> learning rate``.

    The rate is ``init_lr * drop ** floor((epoch + 1) / epochs_drop)``. It is
    multiplied by ``drop`` each time ``epoch + 1`` reaches a multiple of
    ``epochs_drop``, and a message is printed on exactly those epochs.
    """
    def schedule(epoch):
        step = epoch + 1
        n_drops = math.floor(step / epochs_drop)
        new_lr = init_lr * math.pow(drop, n_drops)
        if step % epochs_drop == 0:
            print('Reducing lr. New lr is:', new_lr)
        return new_lr

    return schedule