neuro-symbolic-visual-dialog/prog_generator/optim.py

80 lines
1.9 KiB
Python
Raw Normal View History

2022-08-10 16:49:55 +02:00
"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
# --------------------------------------------------------
# adapted from https://github.com/MILVLG/mcan-vqa/blob/master/core/model/optim.py
# --------------------------------------------------------
import torch
import torch.optim as Optim
class WarmupOptimizer:
    """Wrap an optimizer and drive its learning rate with a simple warm-up.

    For the first ``int(data_size / batch_size * warmup_epochs)`` calls to
    :meth:`step`, every param group's ``'lr'`` is set to
    ``lr_base * warmup_factor``. After that it is set to ``lr_base``. With the
    defaults this means half the base rate for roughly the first epoch.

    The learning rate is rewritten on every :meth:`step`. A change to
    ``lr_base`` (e.g. via ``adjust_lr``) therefore takes effect on the next
    step.
    """

    def __init__(self, lr_base, optimizer, data_size, batch_size,
                 warmup_epochs=1, warmup_factor=0.5):
        """
        Args:
            lr_base: learning rate used once warm-up is over.
            optimizer: wrapped optimizer; must expose ``param_groups``
                (a sequence of dicts), ``step()`` and ``zero_grad()``.
            data_size: presumably the number of training samples; together
                with ``batch_size`` it sets how many steps make up one epoch.
            batch_size: samples per optimizer step.
            warmup_epochs: length of the warm-up phase in epochs. The default
                of 1 matches the previously hard-coded behavior.
            warmup_factor: multiplier applied to ``lr_base`` during warm-up.
                The default of 0.5 matches the previously hard-coded behavior.
        """
        self.optimizer = optimizer
        self._step = 0   # number of step() calls performed so far
        self.lr_base = lr_base
        self._rate = 0   # learning rate applied by the most recent step()
        self.data_size = data_size
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor

    def step(self):
        """Advance the step counter, apply the scheduled rate, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        # Overwrite the lr of every param group so the schedule is authoritative.
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the learning rate scheduled for ``step``.

        Args:
            step: 1-based step index; defaults to the current step count.

        Returns:
            ``lr_base * warmup_factor`` while ``step`` is within the warm-up
            window (inclusive), otherwise ``lr_base``.
        """
        if step is None:
            step = self._step
        # int() truncates, so a trailing partial batch is not counted.
        warmup_steps = int(self.data_size / self.batch_size * self.warmup_epochs)
        if step <= warmup_steps:
            return self.lr_base * self.warmup_factor
        return self.lr_base
def get_optim(opts, model, data_size, lr_base=None):
    """Build the optimizer selected by ``opts.optim``, wrapped in a WarmupOptimizer.

    Only parameters with ``requires_grad`` set are handed to the optimizer.
    The inner optimizer is created with ``lr=0``; WarmupOptimizer assigns the
    real learning rate on every step.

    Args:
        opts: config object. Reads ``lr`` (when ``lr_base`` is None),
            ``optim``, ``eps`` and ``batch_size``. It also reads ``betas`` for
            'adam' or ``weight_decay`` for 'rmsprop'.
        model: object whose ``parameters()`` are optimized.
        data_size: forwarded to WarmupOptimizer to size the warm-up phase.
        lr_base: base learning rate; defaults to ``opts.lr``.

    Returns:
        A WarmupOptimizer wrapping the newly created optimizer.

    Raises:
        ValueError: if ``opts.optim`` is neither 'adam' nor 'rmsprop'.
    """
    if lr_base is None:
        lr_base = opts.lr
    if opts.optim == 'adam':
        # NOTE(review): weight_decay is only forwarded to RMSprop, not Adam —
        # confirm this asymmetry is intentional.
        optim = Optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            betas=opts.betas,
            eps=opts.eps,
        )
    elif opts.optim == 'rmsprop':
        optim = Optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            eps=opts.eps,
            weight_decay=opts.weight_decay,
        )
    else:
        # Previously misspelled as `.fromat`, which raised AttributeError
        # instead of this ValueError.
        raise ValueError('{} optimizer is not supported'.format(opts.optim))
    return WarmupOptimizer(
        lr_base,
        optim,
        data_size,
        opts.batch_size,
    )
def adjust_lr(optim, decay_r):
    """Multiply ``optim.lr_base`` by ``decay_r`` in place.

    Only the stored base rate is changed. For a WarmupOptimizer, the new value
    reaches the param groups the next time ``step()`` recomputes the rate.

    Args:
        optim: object exposing a mutable ``lr_base`` attribute.
        decay_r: multiplicative decay factor.
    """
    scaled = optim.lr_base
    scaled *= decay_r
    optim.lr_base = scaled