vlcn/core/model/optim.py

99 lines
2.7 KiB
Python

# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
import torch
import torch.optim as Optim
class WarmupOptimizer(object):
    """Optimizer wrapper that applies a staircase learning-rate warmup.

    Before every update the wrapped optimizer's learning rate is overwritten
    with ``rate()``. The rate is a quarter, then half, then three quarters of
    ``lr_base`` over three consecutive stages of ``data_size / batch_size``
    steps each. After that it is the full ``lr_base``.

    NOTE(review): ``data_size / batch_size`` is presumably the number of
    iterations per epoch, which would make each warmup stage one epoch long —
    confirm against the caller.
    """

    def __init__(self, lr_base, optimizer, data_size, batch_size):
        self.optimizer = optimizer
        self.lr_base = lr_base
        self.data_size = data_size
        self.batch_size = batch_size
        # Count of step() calls so far, and the learning rate applied by the
        # most recent one (0 until the first step).
        self._step = 0
        self._rate = 0

    def step(self):
        """Advance the step counter, apply the warmup lr, then update params."""
        self._step += 1
        current_lr = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = current_lr
        self._rate = current_lr
        self.optimizer.step()

    def zero_grad(self):
        """Clear gradients on the wrapped optimizer."""
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (defaults to the current step).

        Stage boundaries are inclusive. A step ``s`` falls in stage ``k`` (1..3)
        when ``s <= int(data_size / batch_size * k)`` and it is not in an
        earlier stage. Steps beyond stage 3 get the full ``lr_base``.
        """
        current_step = self._step if step is None else step
        for stage in (1, 2, 3):
            boundary = int(self.data_size / self.batch_size * stage)
            if current_step <= boundary:
                return self.lr_base * stage / 4.
        return self.lr_base
def get_optim(__C, model, data_size, optimizer, lr_base=None):
    """Build the named torch optimizer and wrap it in a WarmupOptimizer.

    Args:
        __C: config object; reads ``LR_BASE`` (when ``lr_base`` is None),
            ``OPT_BETAS``, ``OPT_EPS``, ``OPT_WEIGHT_DECAY`` and ``BATCH_SIZE``.
        model: module whose parameters with ``requires_grad`` are optimized.
        data_size: forwarded to WarmupOptimizer to size its warmup stages.
        optimizer: ``'adam'`` or ``'rmsprop'``.
        lr_base: base learning rate; defaults to ``__C.LR_BASE``.

    Returns:
        WarmupOptimizer(lr_base, <torch optimizer>, data_size, __C.BATCH_SIZE)

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.
    """
    if lr_base is None:
        lr_base = __C.LR_BASE

    # NOTE(review): an earlier revision built per-module parameter groups here,
    # giving modules whose name contains 'dnc' the rate __C.LR_DNC_BASE. That code
    # is disabled, so all trainable parameters share a single group.

    # lr=0 is only a placeholder. WarmupOptimizer.step() overwrites every param
    # group's 'lr' with its warmup rate before each update.
    if optimizer == 'adam':
        optim = Optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            betas=__C.OPT_BETAS,
            eps=__C.OPT_EPS,
        )
    elif optimizer == 'rmsprop':
        optim = Optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            eps=__C.OPT_EPS,
            weight_decay=__C.OPT_WEIGHT_DECAY,
        )
    else:
        # Previously misspelled as ``.fromat``, which raised AttributeError
        # instead of this ValueError.
        raise ValueError('{} optimizer is not supported'.format(optimizer))

    return WarmupOptimizer(
        lr_base,
        optim,
        data_size,
        __C.BATCH_SIZE,
    )
def adjust_lr(optim, decay_r):
    """Multiply ``optim.lr_base`` by ``decay_r`` and store the result back.

    For the WarmupOptimizer in this file, later rate() calls scale from the
    new base value.
    """
    scaled = optim.lr_base
    scaled *= decay_r
    optim.lr_base = scaled
def adjust_lr_dnc(optim, decay_r):
    """Multiply ``optim.lr_dnc_base`` by ``decay_r`` and store the result back.

    NOTE(review): WarmupOptimizer in this file never assigns ``lr_dnc_base``.
    Calling this on the object returned by get_optim therefore raises
    AttributeError unless something else sets that attribute — confirm usage.
    """
    scaled = optim.lr_dnc_base
    scaled *= decay_r
    optim.lr_dnc_base = scaled