# VDGR/utils/optim_utils.py
import logging
import math
import numpy as np
import random
import functools
import glog as log
import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """Linear warmup followed by linear decay.

    Linearly increases the learning rate from 0 to the base lr over `warmup_steps`
    training steps, then linearly decays it towards `min_lr` over the remaining
    `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr
                for base_lr in self.base_lrs]
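

# Illustrative usage sketch (not part of the original module): exercising the
# scheduler on a throwaway optimizer to show the warmup -> linear-decay -> floor
# behaviour. The dummy parameter and the step counts below are arbitrary choices
# made for demonstration only.
def _example_warmup_schedule():
    dummy_param = nn.Parameter(torch.zeros(1))
    optimizer = optim.SGD([dummy_param], lr=1e-3)
    scheduler = WarmupLinearScheduleNonZero(
        optimizer, warmup_steps=10, t_total=100, min_lr=1e-5)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    # The lr ramps from 0 towards 1e-3 over the first 10 steps, then decays
    # linearly, but is clamped so it never drops below min_lr.
    return lrs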


def init_optim(model, config):
    gnn_params = []
    encoder_params_with_decay = []
    encoder_params_without_decay = []
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if "gnn" in param_name:
                    gnn_params.append(param)
                elif module_name == 'encoder':
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_decay.append(param)
                    else:
                        encoder_params_with_decay.append(param)

    optimizer_grouped_parameters = [
        {
            'params': gnn_params,
            'weight_decay': config.gnn_weight_decay,
            'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
        },
        {
            'params': encoder_params_without_decay,
            'weight_decay': 0,
            'lr': config['learning_rate_bert']
        },
        {
            'params': encoder_params_with_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_bert']
        }
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )
    return optimizer, scheduler
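

# Illustrative sketch (not part of the original code): `init_optim` reads the
# config both by attribute (config.gnn_weight_decay, config.use_diff_lr_gnn) and
# by key (config['learning_rate_bert']), so the config object must support both
# access styles. The keys and values below are assumptions chosen purely for
# demonstration.
class _AttrDict(dict):
    """Minimal dict that also exposes its keys as attributes."""
    __getattr__ = dict.__getitem__


def _example_init_optim(model):
    config = _AttrDict(
        learning_rate_bert=2e-5,
        learning_rate_gnn=1e-4,
        use_diff_lr_gnn=True,
        gnn_weight_decay=0.01,
        warmup_steps=1000,
        train_steps=20000,
        min_lr=1e-5,
    )
    optimizer, scheduler = init_optim(model, config)
    return optimizer, scheduler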


def build_torch_optimizer(model, config):
    """Builds the PyTorch optimizer.

    We use the default Adam parameters suggested by the original paper
    https://arxiv.org/pdf/1412.6980.pdf. These values are also used by other
    established implementations, e.g.
    https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer and
    https://keras.io/optimizers/.
    The paper "Attention is all you need" (https://arxiv.org/pdf/1706.03762.pdf)
    uses slightly different values, in particular beta2=0.98; however,
    beta2=0.999 is still arguably the more established value, so we use it here
    as well.

    Args:
        model: The model to optimize.
        config: The dictionary of options.

    Returns:
        A ``torch.optim.Optimizer`` instance.
    """
    betas = (0.9, 0.999)
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    params = {'bert': [], 'task': []}
    for module_name, module in model.named_children():
        if module_name == 'encoder':
            param_type = 'bert'
        else:
            param_type = 'task'
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if any(ex in param_name for ex in exclude_from_weight_decay):
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0
                        }
                    ]
                else:
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0.01
                        }
                    ]

    if config['task_optimizer'] == 'adamw':
        log.info('Using AdamW as task optimizer')
        # AdamW implements Adam with decoupled weight decay.
        task_optimizer = AdamW(params['task'],
                               lr=config["learning_rate_task"],
                               betas=betas,
                               eps=1e-6)
    elif config['task_optimizer'] == 'adam':
        log.info('Using Adam as task optimizer')
        task_optimizer = optim.Adam(params['task'],
                                    lr=config["learning_rate_task"],
                                    betas=betas,
                                    eps=1e-6)
    else:
        raise ValueError(f'Unknown task optimizer: {config["task_optimizer"]}')

    if len(params['bert']) > 0:
        bert_optimizer = AdamW(params['bert'],
                               lr=config["learning_rate_bert"],
                               betas=betas,
                               eps=1e-6)
        optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
    else:
        optimizer = task_optimizer
    return optimizer


def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
    """Returns the learning rate decay function from options."""
    if decay_method == "linear":
        return functools.partial(
            linear_decay,
            global_steps=train_steps,
            **kwargs)
    elif decay_method == "exp":
        return functools.partial(
            exp_decay,
            global_steps=train_steps,
            **kwargs)
    else:
        raise ValueError(f'{decay_method} not found')


def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               (1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
               end_learning_rate


def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               ((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
               end_learning_rate
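

# Illustrative sketch (not part of the original code): the decay helpers return
# a multiplier in [0, 1] that `OptimizerBase.learning_rate()` later applies to
# the base learning rates. The step counts below are arbitrary examples.
def _example_decay_multipliers():
    # Halfway through warmup the multiplier is 0.5; it reaches 1.0 exactly when
    # warmup ends.
    assert linear_decay(50, 1000, warmup_steps=100) == 0.5
    assert linear_decay(100, 1000, warmup_steps=100) == 1.0
    # Halfway through the remaining 900 steps the linear multiplier is 0.5,
    # while the quadratic schedule has already decayed to 0.25.
    assert linear_decay(550, 1000, warmup_steps=100) == 0.5
    assert exp_decay(550, 1000, decay_exp=2, warmup_steps=100) == 0.25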


class MultipleOptimizer(object):
    """Wraps several optimizers so they can be driven as a single one
    (originally needed for sparse Adam)."""

    def __init__(self, op):
        """Stores the list of underlying optimizers."""
        self.optimizers = op

    @property
    def param_groups(self):
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        return param_groups

    def zero_grad(self):
        """Zeroes the gradients of all underlying optimizers."""
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        """Performs an update step with every underlying optimizer."""
        for op in self.optimizers:
            op.step()

    @property
    def state(self):
        """Merged parameter state of all underlying optimizers."""
        return {k: v for op in self.optimizers for k, v in op.state.items()}

    def state_dict(self):
        """Returns one state dict per underlying optimizer."""
        return [op.state_dict() for op in self.optimizers]

    def load_state_dict(self, state_dicts):
        """Loads one state dict per underlying optimizer, in order."""
        assert len(state_dicts) == len(self.optimizers)
        for i in range(len(state_dicts)):
            self.optimizers[i].load_state_dict(state_dicts[i])
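

# Illustrative sketch (not part of the original code): MultipleOptimizer lets
# two underlying optimizers (here SGD and Adam on dummy parameters) be driven
# through a single zero_grad()/step()/state_dict() interface, mirroring how the
# BERT and task optimizers are combined in build_torch_optimizer.
def _example_multiple_optimizer():
    p1 = nn.Parameter(torch.zeros(2))
    p2 = nn.Parameter(torch.zeros(2))
    multi = MultipleOptimizer([
        optim.SGD([p1], lr=0.1),
        optim.Adam([p2], lr=1e-3),
    ])
    multi.zero_grad()
    (p1.sum() + p2.sum()).backward()
    multi.step()                    # steps both optimizers
    state = multi.state_dict()      # one state dict per underlying optimizer
    multi.load_state_dict(state)    # restores them in the same order
    return multi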


class OptimizerBase(object):
    """
    Controller class for optimization. Mostly a thin wrapper around `optim`,
    but also useful for implementing rate scheduling beyond what is currently
    available. Also implements the methods needed for training RNNs, such as
    gradient manipulations.
    """

    def __init__(self,
                 optimizer,
                 learning_rate,
                 learning_rate_decay_fn=None,
                 max_grad_norm=None):
        """Initializes the controller.

        Args:
            optimizer: A ``torch.optim.Optimizer`` instance.
            learning_rate: The initial learning rate.
            learning_rate_decay_fn: An optional callable taking the current step
                as argument and returning a learning rate scaling factor.
            max_grad_norm: Clip gradients to this global norm.
        """
        self._optimizer = optimizer
        self._learning_rate = learning_rate
        self._learning_rate_decay_fn = learning_rate_decay_fn
        self._max_grad_norm = max_grad_norm or 0
        self._training_step = 1
        self._decay_step = 1

    @classmethod
    def from_opt(cls, model, config, checkpoint=None):
        """Builds the optimizer from options.

        Args:
            cls: The ``OptimizerBase`` class to instantiate.
            model: The model to optimize.
            config: The dict of user options.
            checkpoint: An optional checkpoint to load states from.

        Returns:
            An ``OptimizerBase`` instance.
        """
        optim_opt = config
        optim_state_dict = None

        if config["loads_ckpt"] and checkpoint is not None:
            optim = checkpoint['optim']
            ckpt_opt = checkpoint['opt']
            ckpt_state_dict = {}
            if isinstance(optim, Optimizer):  # Backward compatibility.
                ckpt_state_dict['training_step'] = optim._step + 1
                ckpt_state_dict['decay_step'] = optim._step + 1
                ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
            else:
                ckpt_state_dict = optim

            if config["reset_optim"] == 'none':
                # Load everything from the checkpoint.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
            elif config["reset_optim"] == 'all':
                # Build everything from scratch.
                pass
            elif config["reset_optim"] == 'states':
                # Reset optimizer states, keep options.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
                del optim_state_dict['optimizer']
            elif config["reset_optim"] == 'keep_states':
                # Reset options, keep optimizer states.
                optim_state_dict = ckpt_state_dict

        learning_rates = [
            optim_opt["learning_rate_bert"],
            optim_opt["learning_rate_gnn"]
        ]
        decay_fn = [
            make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
            make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
        ]
        optimizer = cls(
            build_torch_optimizer(model, optim_opt),
            learning_rates,
            learning_rate_decay_fn=decay_fn,
            max_grad_norm=optim_opt["max_grad_norm"])
        if optim_state_dict:
            optimizer.load_state_dict(optim_state_dict)
        return optimizer

    @property
    def training_step(self):
        """The current training step."""
        return self._training_step

    def learning_rate(self):
        """Returns the current learning rate(s)."""
        if self._learning_rate_decay_fn is None:
            return self._learning_rate
        return [decay_fn(self._decay_step) * learning_rate
                for decay_fn, learning_rate
                in zip(self._learning_rate_decay_fn, self._learning_rate)]

    def state_dict(self):
        return {
            'training_step': self._training_step,
            'decay_step': self._decay_step,
            'optimizer': self._optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        self._training_step = state_dict['training_step']
        # State can be partially restored.
        if 'decay_step' in state_dict:
            self._decay_step = state_dict['decay_step']
        if 'optimizer' in state_dict:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def zero_grad(self):
        """Zeroes the gradients of the optimized parameters."""
        self._optimizer.zero_grad()

    def backward(self, loss):
        """Wrapper for the backward pass. Some optimizers require ownership of
        the backward pass."""
        loss.backward()

    def step(self):
        """Updates the model parameters based on the current gradients.

        Optionally rescales each group's learning rate and clips gradients
        before the update.
        """
        learning_rate = self.learning_rate()
        if isinstance(self._optimizer, MultipleOptimizer):
            optimizers = self._optimizer.optimizers
        else:
            optimizers = [self._optimizer]
        for lr, op in zip(learning_rate, optimizers):
            for group in op.param_groups:
                group['lr'] = lr
                if self._max_grad_norm > 0:
                    clip_grad_norm_(group['params'], self._max_grad_norm)
        self._optimizer.step()
        self._decay_step += 1
        self._training_step += 1
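

# Illustrative sketch (not part of the original training code): how
# `OptimizerBase.from_opt` and the zero_grad/backward/step cycle are typically
# wired together. `config` is assumed to provide the keys read above
# (loads_ckpt, reset_optim, learning_rate_bert, learning_rate_gnn,
# decay_method_bert, decay_method_gnn, warmup_steps, decay_exp, train_steps,
# max_grad_norm, task_optimizer, learning_rate_task), and `model(batch)` is
# assumed to return a scalar loss; both are placeholders, not the project's
# actual training loop.
def _example_training_loop(model, config, batches):
    optimizer = OptimizerBase.from_opt(model, config)
    for batch in batches:
        optimizer.zero_grad()
        loss = model(batch)
        optimizer.backward(loss)
        optimizer.step()   # rescales each group's lr, clips gradients, updates
    return optimizer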