80 lines
1.9 KiB
Python
80 lines
1.9 KiB
Python
|
"""
|
||
|
author: Adnen Abdessaied
|
||
|
maintainer: "Adnen Abdessaied"
|
||
|
website: adnenabdessaied.de
|
||
|
version: 1.0.1
|
||
|
"""
|
||
|
|
||
|
# --------------------------------------------------------
|
||
|
# adapted from https://github.com/MILVLG/mcan-vqa/blob/master/core/model/optim.py
|
||
|
# --------------------------------------------------------
|
||
|
|
||
|
import torch
|
||
|
import torch.optim as Optim
|
||
|
|
||
|
|
||
|
class WarmupOptimizer(object):
    """Wrap an optimizer and set its learning rate with a simple warmup.

    During the first ``int(data_size / batch_size * warmup_epochs)`` steps the
    learning rate is ``lr_base * warmup_factor``; afterwards it is ``lr_base``.
    ``lr_base`` is re-read on every call to ``rate()``, so code that rescales
    it between steps takes effect on the next ``step()``.

    Args:
        lr_base: learning rate used after the warmup phase.
        optimizer: the wrapped optimizer. It must expose ``param_groups``
            (an iterable of dicts whose ``'lr'`` entry is overwritten),
            ``step()`` and ``zero_grad()``.
        data_size: together with ``batch_size`` this sets the warmup length;
            presumably the number of training samples -- TODO confirm.
        batch_size: number of samples per optimizer step.
        warmup_epochs: warmup length in multiples of
            ``data_size / batch_size`` steps. Defaults to 1, which matches
            the previously hard-coded value.
        warmup_factor: multiplier applied to ``lr_base`` during warmup.
            Defaults to 0.5, which matches the previously hard-coded value.
    """

    def __init__(self, lr_base, optimizer, data_size, batch_size,
                 warmup_epochs=1, warmup_factor=0.5):
        self.optimizer = optimizer
        # Number of completed calls to step(); 0 before the first step.
        self._step = 0
        self.lr_base = lr_base
        # Learning rate applied by the most recent step() (0 before any step).
        self._rate = 0
        self.data_size = data_size
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor

    def step(self):
        """Advance the counter, push the new lr to every param group, then step."""
        self._step += 1

        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate

        self.optimizer.step()

    def zero_grad(self):
        """Forward to the wrapped optimizer's ``zero_grad()``."""
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (defaults to the current step).

        Steps up to and including the warmup boundary use
        ``lr_base * warmup_factor``; later steps use ``lr_base``.
        """
        if step is None:
            step = self._step

        # Same formula as before (float division then truncation), so the
        # default warmup boundary is unchanged.
        warmup_steps = int(self.data_size / self.batch_size * self.warmup_epochs)
        if step <= warmup_steps:
            return self.lr_base * self.warmup_factor
        return self.lr_base
|
||
|
|
||
|
|
||
|
def get_optim(opts, model, data_size, lr_base=None):
    """Build a torch optimizer for ``model`` wrapped in a ``WarmupOptimizer``.

    Only parameters with ``requires_grad`` set are handed to the optimizer.
    The optimizer is created with ``lr=0``; the real learning rate is written
    into its param groups by ``WarmupOptimizer.step()`` on every step.

    Args:
        opts: configuration object. Reads ``optim`` ('adam' or 'rmsprop'),
            ``batch_size``, ``lr`` (only when ``lr_base`` is None), ``eps``,
            ``betas`` (adam only) and ``weight_decay`` (rmsprop only).
        model: module whose ``parameters()`` are optimized.
        data_size: forwarded to ``WarmupOptimizer`` to size the warmup.
        lr_base: base learning rate; defaults to ``opts.lr``.

    Returns:
        A ``WarmupOptimizer`` wrapping the created optimizer.

    Raises:
        ValueError: if ``opts.optim`` names an unsupported optimizer.
    """
    if lr_base is None:
        lr_base = opts.lr

    if opts.optim == 'adam':
        optim = Optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            betas=opts.betas,
            eps=opts.eps,
        )
    elif opts.optim == 'rmsprop':
        optim = Optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            eps=opts.eps,
            weight_decay=opts.weight_decay
        )
    else:
        # Previously misspelled as ``.fromat``, which raised AttributeError
        # instead of this ValueError.
        raise ValueError('{} optimizer is not supported'.format(opts.optim))

    return WarmupOptimizer(
        lr_base,
        optim,
        data_size,
        opts.batch_size
    )
|
||
|
|
||
|
def adjust_lr(optim, decay_r):
    """Multiply ``optim.lr_base`` by ``decay_r``, mutating ``optim``.

    Args:
        optim: object carrying an ``lr_base`` attribute (in this module, a
            ``WarmupOptimizer``, whose ``rate()`` re-reads ``lr_base``).
        decay_r: multiplicative decay factor.

    Returns:
        None; the change is made on ``optim`` itself.
    """
    decayed_base = optim.lr_base * decay_r
    optim.lr_base = decayed_base
|