from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """Linear warmup followed by linear decay.

    Linearly increases the learning rate from 0 to each parameter group's base
    learning rate over `warmup_steps` training steps, then linearly decays it
    over the remaining `t_total - warmup_steps` steps, clamping it so that it
    never falls below `min_lr` (hence "NonZero").
    """

    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # Warmup: scale linearly from 0 up to the base learning rate.
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            # Decay: scale linearly from the base learning rate towards 0.
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

        # Clamp each group's learning rate so it never drops below min_lr.
        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
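

# Illustrative usage sketch for WarmupLinearScheduleNonZero (not part of the
# original training pipeline): it shows the learning rate climbing linearly to
# the optimizer's base lr over `warmup_steps`, then decaying linearly while
# never dropping below `min_lr`. The tiny model and the step counts are
# hypothetical placeholders.
def _demo_warmup_schedule():
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = AdamW(model.parameters(), lr=1e-4)
    scheduler = WarmupLinearScheduleNonZero(
        optimizer, warmup_steps=10, t_total=100, min_lr=1e-6
    )

    lrs = []
    for _ in range(100):
        optimizer.step()  # in real training this follows a forward/backward pass
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])

    # lrs rises from ~1e-5 to 1e-4 over the first 10 steps, then decays
    # linearly and is clamped at 1e-6 at the very end of the schedule.
    return lrs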


def init_optim(model, config):
    encoder_params_with_weight_decay = []
    encoder_params_without_weight_decay = []
    decoder_params_with_weight_decay = []
    decoder_params_without_weight_decay = []
    other_params_with_weight_decay = []
    other_params_without_weight_decay = []

    # Biases and LayerNorm parameters are conventionally excluded from weight decay.
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Our model shares (embedding) parameters between the encoder and decoder.
    # We want to include such parameters in only one parameter group, so we
    # keep track of the unique id of each parameter seen so far.
    params_ids = []

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            # Skip parameters already assigned to a group (shared weights).
            if id(param) not in params_ids:
                params_ids.append(id(param))
            else:
                continue
            if param.requires_grad:
                if 'encoder' in param_name:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_weight_decay.append(param)
                    else:
                        encoder_params_with_weight_decay.append(param)
                elif 'decoder' in param_name:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        decoder_params_without_weight_decay.append(param)
                    else:
                        decoder_params_with_weight_decay.append(param)
                else:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        other_params_without_weight_decay.append(param)
                    else:
                        other_params_with_weight_decay.append(param)

    # Six parameter groups: encoder / decoder / everything else, each split
    # into weight-decayed and non-weight-decayed parameters.
    optimizer_grouped_parameters = [
        {
            'params': encoder_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_bart']
        },
        {
            'params': encoder_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_bart']
        },
        {
            'params': decoder_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_bart']
        },
        {
            'params': decoder_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_bart']
        },
        {
            'params': other_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_other']
        },
        {
            'params': other_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_other']
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_bart'])

    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )

    return optimizer, scheduler
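

# Illustrative usage sketch for init_optim (not part of the original code).
# The toy encoder/decoder model and the config values below are hypothetical;
# the real pipeline supplies its own model and configuration dictionary.
def _demo_init_optim():
    import torch

    class ToySeq2Seq(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Parameter names under this child contain 'encoder' / 'decoder',
            # which is what init_optim keys on when building the groups.
            self.model = torch.nn.ModuleDict({
                'encoder': torch.nn.Linear(16, 16),
                'decoder': torch.nn.Linear(16, 16),
            })
            self.classifier = torch.nn.Linear(16, 2)

    config = {
        'learning_rate_bart': 1e-5,
        'learning_rate_other': 1e-4,
        'warmup_steps': 100,
        'train_steps': 1000,
        'min_lr': 1e-6,
    }

    optimizer, scheduler = init_optim(ToySeq2Seq(), config)
    # optimizer holds six parameter groups (encoder/decoder/other, with and
    # without weight decay); calling scheduler.step() after each
    # optimizer.step() updates the learning rate of every group.
    return optimizer, scheduler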