import logging
import math
import numpy as np
import random
import functools
import glog as log

import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """Linear warmup followed by linear decay.

    Linearly increases the learning rate from 0 to the base learning rate over
    `warmup_steps` training steps, then linearly decreases it towards `min_lr`
    over the remaining `t_total - warmup_steps` steps. The learning rate never
    drops below `min_lr`.
    """
    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr
                for base_lr in self.base_lrs]


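# Usage sketch (hyper-parameter values are illustrative only): the scheduler
# wraps any torch optimizer and is stepped once per training iteration.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
#   scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=1000,
#                                           t_total=20000, min_lr=1e-5)
#   for batch in loader:
#       loss = model(batch)
#       loss.backward()
#       optimizer.step()
#       scheduler.step()
#       optimizer.zero_grad()

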
def init_optim(model, config):
    """Builds an AdamW optimizer with separate parameter groups for the GNN
    parameters and the BERT encoder, plus a warmup/linear-decay scheduler."""
    gnn_params = []

    encoder_params_with_decay = []
    encoder_params_without_decay = []

    # Parameters that never receive weight decay (biases and LayerNorm weights).
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if "gnn" in param_name:
                    gnn_params.append(param)
                elif module_name == 'encoder':
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_decay.append(param)
                    else:
                        encoder_params_with_decay.append(param)

    optimizer_grouped_parameters = [
        {
            'params': gnn_params,
            'weight_decay': config['gnn_weight_decay'],
            'lr': config['learning_rate_gnn'] if config['use_diff_lr_gnn'] else config['learning_rate_bert']
        }
    ]

    optimizer_grouped_parameters.extend(
        [
            {
                'params': encoder_params_without_decay,
                'weight_decay': 0,
                'lr': config['learning_rate_bert']
            },
            {
                'params': encoder_params_with_decay,
                'weight_decay': 0.01,
                'lr': config['learning_rate_bert']
            }
        ]
    )
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )

    return optimizer, scheduler


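# A minimal config sketch for init_optim; the keys mirror the accesses above,
# the values are illustrative rather than the ones used in training:
#
#   config = {
#       'learning_rate_gnn': 1e-4,
#       'learning_rate_bert': 2e-5,
#       'gnn_weight_decay': 0.01,
#       'use_diff_lr_gnn': True,
#       'warmup_steps': 1000,
#       'train_steps': 20000,
#       'min_lr': 1e-5,
#   }
#   optimizer, scheduler = init_optim(model, config)

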
def build_torch_optimizer(model, config):
    """Builds the PyTorch optimizer.

    We use the default Adam parameters suggested by the original paper
    (https://arxiv.org/pdf/1412.6980.pdf). These values are also used by other
    established implementations, e.g.
    https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer and
    https://keras.io/optimizers/. "Attention Is All You Need"
    (https://arxiv.org/pdf/1706.03762.pdf) uses slightly different values, in
    particular beta2=0.98; however, beta2=0.999 is still arguably the more
    established value, so we use it here as well.

    Args:
        model: The model to optimize.
        config: The dictionary of options.

    Returns:
        A ``torch.optim.Optimizer`` instance.
    """
    betas = [0.9, 0.999]
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Split parameters into the BERT encoder and the task-specific modules so
    # that each group can get its own optimizer and learning rate.
    params = {'bert': [], 'task': []}
    for module_name, module in model.named_children():
        if module_name == 'encoder':
            param_type = 'bert'
        else:
            param_type = 'task'
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if any(ex in param_name for ex in exclude_from_weight_decay):
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0
                        }
                    ]
                else:
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0.01
                        }
                    ]

    if config['task_optimizer'] == 'adamw':
        log.info('Using AdamW as task optimizer')
        task_optimizer = AdamW(params['task'],
                               lr=config["learning_rate_task"],
                               betas=betas,
                               eps=1e-6)
    elif config['task_optimizer'] == 'adam':
        log.info('Using Adam as task optimizer')
        task_optimizer = optim.Adam(params['task'],
                                    lr=config["learning_rate_task"],
                                    betas=betas,
                                    eps=1e-6)
    else:
        raise ValueError(f"Unknown task optimizer: {config['task_optimizer']}")

    if len(params['bert']) > 0:
        bert_optimizer = AdamW(params['bert'],
                               lr=config["learning_rate_bert"],
                               betas=betas,
                               eps=1e-6)
        optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
    else:
        optimizer = task_optimizer

    return optimizer


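# build_torch_optimizer reads at least the following config keys; the values
# below are illustrative only:
#
#   config = {
#       'task_optimizer': 'adamw',   # or 'adam'
#       'learning_rate_task': 1e-4,
#       'learning_rate_bert': 2e-5,
#   }
#   optimizer = build_torch_optimizer(model, config)

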
def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
    """Returns the learning rate decay function for the given options."""
    if decay_method == "linear":
        return functools.partial(
            linear_decay,
            global_steps=train_steps,
            **kwargs)
    elif decay_method == "exp":
        return functools.partial(
            exp_decay,
            global_steps=train_steps,
            **kwargs)
    else:
        raise ValueError(f'{decay_method} not found')


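# For example (illustrative arguments), the returned callable maps a step
# number to a learning rate scale factor:
#
#   decay_fn = make_learning_rate_decay_fn('linear', 20000, warmup_steps=1000)
#   decay_fn(500)    # 0.5  (halfway through warmup)
#   decay_fn(20000)  # 0.0  (fully decayed)

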
def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    """Linear warmup followed by linear decay; returns a learning rate scale factor."""
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               (1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
               end_learning_rate


def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    """Linear warmup followed by polynomial decay with exponent `decay_exp`."""
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               ((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
               end_learning_rate


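# Worked example (illustrative numbers): with warmup_steps=100 and
# global_steps=1100, exp_decay with decay_exp=1 coincides with linear_decay;
# larger exponents decay faster after warmup. At step=600 the remaining
# fraction is (1 - 500/1000) = 0.5, so the factor is 0.5 for decay_exp=1 but
# 0.5 ** 2 = 0.25 for decay_exp=2. These factors are later multiplied by the
# base learning rates in OptimizerBase.learning_rate().

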
class MultipleOptimizer(object):
    """Drives several optimizers as if they were one (originally needed for
    sparse Adam; here it pairs a BERT-encoder optimizer with a task optimizer)."""

    def __init__(self, op):
        """`op` is a list of ``torch.optim.Optimizer`` instances."""
        self.optimizers = op

    @property
    def param_groups(self):
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        return param_groups

    def zero_grad(self):
        """Zeroes the gradients of all wrapped optimizers."""
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        """Steps all wrapped optimizers."""
        for op in self.optimizers:
            op.step()

    @property
    def state(self):
        """Merged parameter state of all wrapped optimizers."""
        return {k: v for op in self.optimizers for k, v in op.state.items()}

    def state_dict(self):
        """Returns a list of state dicts, one per wrapped optimizer."""
        return [op.state_dict() for op in self.optimizers]

    def load_state_dict(self, state_dicts):
        """Loads one state dict per wrapped optimizer, in order."""
        assert len(state_dicts) == len(self.optimizers)
        for i in range(len(state_dicts)):
            self.optimizers[i].load_state_dict(state_dicts[i])


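# Sketch: from the caller's point of view MultipleOptimizer behaves like a
# single optimizer (the `encoder`/`decoder` modules and learning rates below
# are hypothetical):
#
#   opt = MultipleOptimizer([
#       torch.optim.Adam(encoder.parameters(), lr=2e-5),
#       torch.optim.Adam(decoder.parameters(), lr=1e-4),
#   ])
#   opt.zero_grad()
#   loss.backward()
#   opt.step()

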
class OptimizerBase(object):
    """
    Controller class for optimization. Mostly a thin wrapper around `optim`,
    but also useful for implementing learning rate scheduling beyond what is
    currently available. Also implements methods needed when training RNNs,
    such as gradient manipulation.
    """

    def __init__(self,
                 optimizer,
                 learning_rate,
                 learning_rate_decay_fn=None,
                 max_grad_norm=None):
        """Initializes the controller.

        Args:
            optimizer: A ``torch.optim.Optimizer`` instance.
            learning_rate: The initial learning rate (or list of learning
                rates, one per wrapped optimizer).
            learning_rate_decay_fn: An optional callable (or list of callables)
                taking the current step as argument and returning a learning
                rate scaling factor.
            max_grad_norm: Clip gradients to this global norm.
        """
        self._optimizer = optimizer
        self._learning_rate = learning_rate
        self._learning_rate_decay_fn = learning_rate_decay_fn
        self._max_grad_norm = max_grad_norm or 0
        self._training_step = 1
        self._decay_step = 1

    @classmethod
    def from_opt(cls, model, config, checkpoint=None):
        """Builds the optimizer from options.

        Args:
            cls: The ``OptimizerBase`` class to instantiate.
            model: The model to optimize.
            config: The dict of user options.
            checkpoint: An optional checkpoint to load states from.

        Returns:
            An ``OptimizerBase`` instance.
        """
        optim_opt = config
        optim_state_dict = None

        if config["loads_ckpt"] and checkpoint is not None:
            optim = checkpoint['optim']
            ckpt_opt = checkpoint['opt']
            ckpt_state_dict = {}
            if isinstance(optim, Optimizer):  # Backward compatibility.
                ckpt_state_dict['training_step'] = optim._step + 1
                ckpt_state_dict['decay_step'] = optim._step + 1
                ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
            else:
                ckpt_state_dict = optim

            if config["reset_optim"] == 'none':
                # Load everything from the checkpoint.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
            elif config["reset_optim"] == 'all':
                # Build everything from scratch.
                pass
            elif config["reset_optim"] == 'states':
                # Reset optimizer states, keep options.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
                del optim_state_dict['optimizer']
            elif config["reset_optim"] == 'keep_states':
                # Reset options, keep optimizer states.
                optim_state_dict = ckpt_state_dict

        learning_rates = [
            optim_opt["learning_rate_bert"],
            optim_opt["learning_rate_gnn"]
        ]
        decay_fn = [
            make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
            make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
        ]
        optimizer = cls(
            build_torch_optimizer(model, optim_opt),
            learning_rates,
            learning_rate_decay_fn=decay_fn,
            max_grad_norm=optim_opt["max_grad_norm"])
        if optim_state_dict:
            optimizer.load_state_dict(optim_state_dict)
        return optimizer

    @property
    def training_step(self):
        """The current training step."""
        return self._training_step

    def learning_rate(self):
        """Returns the current learning rate(s)."""
        if self._learning_rate_decay_fn is None:
            return self._learning_rate
        return [decay_fn(self._decay_step) * learning_rate
                for decay_fn, learning_rate
                in zip(self._learning_rate_decay_fn, self._learning_rate)]

    def state_dict(self):
        return {
            'training_step': self._training_step,
            'decay_step': self._decay_step,
            'optimizer': self._optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        self._training_step = state_dict['training_step']
        # State can be partially restored.
        if 'decay_step' in state_dict:
            self._decay_step = state_dict['decay_step']
        if 'optimizer' in state_dict:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def zero_grad(self):
        """Zero the gradients of optimized parameters."""
        self._optimizer.zero_grad()

    def backward(self, loss):
        """Wrapper for the backward pass. Some optimizers require ownership of
        the backward pass."""
        loss.backward()

    def step(self):
        """Update the model parameters based on current gradients.

        Optionally applies gradient clipping and updates the learning rate.
        """
        learning_rate = self.learning_rate()

        if isinstance(self._optimizer, MultipleOptimizer):
            optimizers = self._optimizer.optimizers
        else:
            optimizers = [self._optimizer]
        for lr, op in zip(learning_rate, optimizers):
            for group in op.param_groups:
                group['lr'] = lr
                if self._max_grad_norm > 0:
                    clip_grad_norm_(group['params'], self._max_grad_norm)
        self._optimizer.step()
        self._decay_step += 1
        self._training_step += 1
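

# End-to-end sketch of how these pieces fit together. The config keys mirror
# the accesses in from_opt and build_torch_optimizer; the values are
# illustrative only:
#
#   config = {
#       'loads_ckpt': False, 'reset_optim': 'none',
#       'task_optimizer': 'adamw',
#       'learning_rate_bert': 2e-5, 'learning_rate_gnn': 1e-4,
#       'learning_rate_task': 1e-4,
#       'decay_method_bert': 'linear', 'decay_method_gnn': 'linear',
#       'train_steps': 20000, 'warmup_steps': 1000, 'decay_exp': 1,
#       'max_grad_norm': 1.0,
#   }
#   optim = OptimizerBase.from_opt(model, config)
#   for batch in loader:
#       loss = model(batch)
#       optim.zero_grad()
#       optim.backward(loss)
#       optim.step()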