make code public
This commit is contained in:
commit
9d8b93db26
26 changed files with 11937 additions and 0 deletions
79
prog_generator/optim.py
Normal file
79
prog_generator/optim.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
author: Adnen Abdessaied
|
||||
maintainer: "Adnen Abdessaied"
|
||||
website: adnenabdessaied.de
|
||||
version: 1.0.1
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# adapted from https://github.com/MILVLG/mcan-vqa/blob/master/core/model/optim.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.optim as Optim
|
||||
|
||||
|
||||
class WarmupOptimizer(object):
|
||||
def __init__(self, lr_base, optimizer, data_size, batch_size):
|
||||
self.optimizer = optimizer
|
||||
self._step = 0
|
||||
self.lr_base = lr_base
|
||||
self._rate = 0
|
||||
self.data_size = data_size
|
||||
self.batch_size = batch_size
|
||||
|
||||
def step(self):
|
||||
self._step += 1
|
||||
|
||||
rate = self.rate()
|
||||
for p in self.optimizer.param_groups:
|
||||
p['lr'] = rate
|
||||
self._rate = rate
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def rate(self, step=None):
|
||||
if step is None:
|
||||
step = self._step
|
||||
|
||||
if step <= int(self.data_size / self.batch_size * 1):
|
||||
r = self.lr_base * 1/2.
|
||||
else:
|
||||
r = self.lr_base
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def get_optim(opts, model, data_size, lr_base=None):
|
||||
if lr_base is None:
|
||||
lr_base = opts.lr
|
||||
|
||||
if opts.optim == 'adam':
|
||||
optim = Optim.Adam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=0,
|
||||
betas=opts.betas,
|
||||
eps=opts.eps,
|
||||
|
||||
)
|
||||
elif opts.optim == 'rmsprop':
|
||||
optim = Optim.RMSprop(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=0,
|
||||
eps=opts.eps,
|
||||
weight_decay=opts.weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError('{} optimizer is not supported'.fromat(opts.optim))
|
||||
return WarmupOptimizer(
|
||||
lr_base,
|
||||
optim,
|
||||
data_size,
|
||||
opts.batch_size
|
||||
)
|
||||
|
||||
def adjust_lr(optim, decay_r):
|
||||
optim.lr_base *= decay_r
|
Loading…
Add table
Add a link
Reference in a new issue