# V2Dial/tasks/pre_train.py

import os
import datetime
import wandb
import torch
import pandas as pd
from time import time
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.nn.utils.clip_grad import clip_grad_value_
from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts
from datasets.utils import get_datasets_media
from datasets.dataloader import MetaLoader
from utils.dist import is_main_process, get_rank, get_world_size
from utils.logger import setup_wandb, log_dict_to_wandb
from .retrieval_utils import evaluation_wrapper
import glog as logger


def run_epoch(
model,
train_dataloaders,
optimizer,
epoch,
global_step,
webvid_step,
cc3m_step,
device,
scheduler,
scaler,
config
):
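    """Train the model for one epoch over all media-type dataloaders.

    Batches are drawn from a MetaLoader that mixes the per-media-type loaders
    (cc3m, webvid); per-loss values are logged every `log_freq` steps.
    Returns the updated (global_step, webvid_step, cc3m_step) counters.
    """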
model.train()
media_types = list(train_dataloaders.keys())
log_freq = config['log_freq']
# metric_logger = MetricLogger(delimiter=' ')
# metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}'))
# metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}"))
loss_names = ['loss_' + k for k in config['loss_dict'].keys()]
# for l in loss_names:
# for m in media_types:
# metric_logger.add_meter(
# f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}")
# )
# header = '{} | Epoch = {}'.format(config['stage'], epoch)
model_without_ddp = model
if config['distributed']:
model_without_ddp = model.module
for k in train_dataloaders:
train_dataloaders[k].sampler.set_epoch(epoch)
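    # MetaLoader merges the per-media-type dataloaders into a single iterator
    # that yields (media_type, batch) pairs.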
train_dataloader = MetaLoader(name2loader=train_dataloaders)
log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
log_text_template += '[Losses] mlm (x{}) = {:.4f} | vcc (x{}) = {:.4f} | vcm (x{}) = {:.4f} | stc (x{}) = {:.4f} | stm (x{}) = {:.4f}\n'
log_text_template += '[Other] lr = {:.4f} | temp = {:.4f} | eta = {}\n'
# iterator = metric_logger.log_every(train_dataloader, log_freq, header)
local_step = 0
for media_type, (vis, caption, neg_vis) in train_dataloader:
start = time()
# loss_dict = {}
vis = vis.to(device)
neg_vis = neg_vis.to(device)
# idx = idx.to(device)
with torch.cuda.amp.autocast(enabled=config.fp16):
loss_dict = model(vis, caption, neg_vis, media_type)
# loss_dict.update(losses)
loss = sum(loss_dict.values())
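        # Gradient accumulation: scale the loss down so gradients average over
        # config.accum_grad_every micro-batches before the optimizer step.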
loss_accum_grad = loss / config.accum_grad_every
scaler.scale(loss_accum_grad).backward()
        # Perform gradient clipping: unscale --> clip
if config['clip_grad_value'] > 0:
# scaler.unscale_(optimizer)
clip_grad_value_(model.parameters(), config.clip_grad_value)
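        # Step the optimizer, scheduler, and AMP scaler only once per
        # accumulation window, then reset the gradients.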
if local_step % config.accum_grad_every == 0:
scaler.step(optimizer)
scaler.update()
# scheduler.step(epoch, global_step)
scheduler.step()
optimizer.zero_grad()
time_iter = time() - start
eta = (len(train_dataloader) - local_step - 1) * time_iter
eta = str(datetime.timedelta(seconds=eta))
# log
log_dict_webvid = {}
log_dict_cc3m = {}
log_dict_rest = {}
for loss_name in loss_names:
value = loss_dict[loss_name]
value = value if isinstance(value, float) else value.item()
# metric_logger.update(**{f"{media_type}/{loss_name}": value})
if media_type == "cc3m":
log_dict_cc3m[f"train/{media_type}/{loss_name}"] = value
else:
log_dict_webvid[f"train/{media_type}/{loss_name}"] = value
# metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# metric_logger.update(temperature=model_without_ddp.temp.item())
log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"]
log_dict_rest['train/other/temperature'] = model_without_ddp.temp.item()
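        # Periodic console/wandb logging (main process only), aligned with the
        # gradient-accumulation boundaries.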
if is_main_process() and global_step % log_freq == 0 and local_step % config.accum_grad_every == 0:
log_dict_rest['train/other/step'] = global_step
if media_type == 'cc3m':
log_dict_cc3m['train/cc3m/step'] = cc3m_step
log_text = log_text_template.format(
                    epoch, config.epochs-1, local_step, len(train_dataloader), media_type,
config.loss_dict['mlm'], log_dict_cc3m['train/cc3m/loss_mlm'],
config.loss_dict['vcc'], log_dict_cc3m['train/cc3m/loss_vcc'],
config.loss_dict['vcm'], log_dict_cc3m['train/cc3m/loss_vcm'],
config.loss_dict['stc'], log_dict_cc3m['train/cc3m/loss_stc'],
                    config.loss_dict['stm'], log_dict_cc3m['train/cc3m/loss_stm'],
log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta
)
logger.info(log_text)
if config['wandb_enabled']:
wandb.log(log_dict_rest)
wandb.log(log_dict_cc3m)
# log_text_template = '[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
# log_text_template += '[losses: mlm = {:.4f} | vcc = {:4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f}]\n'
# log_text_template += '[Other: lr = {:.4f} | temp = {:4f}]\n'
else:
log_dict_webvid['train/webvid/step'] = webvid_step
log_text = log_text_template.format(
                    epoch, config.epochs-1, local_step, len(train_dataloader), media_type,
config.loss_dict['mlm'], log_dict_webvid['train/webvid/loss_mlm'],
config.loss_dict['vcc'], log_dict_webvid['train/webvid/loss_vcc'],
config.loss_dict['vcm'], log_dict_webvid['train/webvid/loss_vcm'],
config.loss_dict['stc'], log_dict_webvid['train/webvid/loss_stc'],
config.loss_dict['stm'], log_dict_webvid['train/webvid/loss_stm'],
log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta
)
logger.info(log_text)
if config['wandb_enabled']:
wandb.log(log_dict_rest)
wandb.log(log_dict_webvid)
if media_type == "cc3m":
cc3m_step += 1
else:
webvid_step += 1
global_step += 1
local_step += 1
# gather the stats from all processes
# metric_logger.synchronize_between_processes()
# logger.info(f"Averaged stats: {metric_logger.global_avg()}")
return global_step, webvid_step, cc3m_step


def eval(model, val_dataloader, device, epoch, config):
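    """Run validation on a single dataloader and return averaged losses.

    Losses are all-reduced across ranks when distributed training is enabled
    (individual terms only if their weight is non-zero); per-batch values are
    accumulated on the main process and returned as a dict of epoch averages.
    """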
model.eval()
    log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
log_text_template += '[Losses] mlm = {:.4f} | vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} \n'
# log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
# log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
cum_loss_stc = 0
cum_loss_stm = 0
cum_loss_vcc = 0
cum_loss_vcm = 0
cum_loss_mlm = 0
cum_loss_tot = 0
val_step = 0
# val_dataloader = MetaLoader(name2loader=val_dataloaders)
media_type = val_dataloader.dataset.medium
if is_main_process():
start_time = time()
# for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
for vis, caption, neg_vis in val_dataloader:
# for vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
vis = vis.to(device)
neg_vis = neg_vis.to(device)
# idx = idx.to(device)
with torch.cuda.amp.autocast(enabled=config['fp16']):
with torch.no_grad():
# loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type)
# loss_dict = model(vis, caption, neg_vis, neg_caption, media_type, file, neg_file)
loss_dict = model(vis, caption, neg_vis, media_type)
loss = sum(loss_dict.values())
loss_stc = loss_dict['loss_stc']
loss_stm = loss_dict['loss_stm']
loss_vcc = loss_dict['loss_vcc']
loss_vcm = loss_dict['loss_vcm']
loss_mlm = loss_dict['loss_mlm']
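        # Average the losses across ranks so the logged values are global;
        # individual terms are only reduced if their weight is non-zero.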
if config['distributed']:
dist.all_reduce(loss, op=ReduceOp.AVG)
if config.loss_dict['stc'] != 0:
dist.all_reduce(loss_stc, op=ReduceOp.AVG)
if config.loss_dict['stm'] != 0:
dist.all_reduce(loss_stm, op=ReduceOp.AVG)
if config.loss_dict['vcc'] != 0:
dist.all_reduce(loss_vcc, op=ReduceOp.AVG)
if config.loss_dict['vcm'] != 0:
dist.all_reduce(loss_vcm, op=ReduceOp.AVG)
if config.loss_dict['mlm'] != 0:
dist.all_reduce(loss_mlm, op=ReduceOp.AVG)
if is_main_process():
cum_loss_tot += loss.item()
cum_loss_stc += loss_stc.item()
cum_loss_stm += loss_stm.item()
cum_loss_vcc += loss_vcc.item()
cum_loss_vcm += loss_vcm.item()
cum_loss_mlm += loss_mlm.item()
if val_step % config.log_freq == 0:
log_text = log_text_template.format(
epoch, val_step, len(val_dataloader), media_type,
loss_mlm, loss_vcc, loss_vcm, loss_stc, loss_stm)
# log_text_template = '\n' + '-' * 25 + '\n[Val Eoch{}][Iter. {}/{}][Media-type {}]\n'
# log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
# log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
# log_text = log_text_template.format(
# epoch, val_step, len(val_dataloader), media_type,
# loss_vcc, loss_vcm, loss_stc, loss_stm, 0,
# loss_vhc, loss_vhm, loss_chc, loss_chm, loss_gen
# )
logger.info(log_text)
# logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format(
# epoch, val_step, len(val_dataloader), gen_loss, loss
# ))
val_step += 1
if config['distributed']:
dist.barrier()
if is_main_process():
duration = time() - start_time
cum_loss_tot /= len(val_dataloader)
cum_loss_stc /= len(val_dataloader)
cum_loss_stm /= len(val_dataloader)
cum_loss_vcc /= len(val_dataloader)
cum_loss_vcm /= len(val_dataloader)
cum_loss_mlm /= len(val_dataloader)
# cum_loss_vhc /= len(val_dataloader)
# cum_loss_vhm /= len(val_dataloader)
# cum_loss_chc /= len(val_dataloader)
# cum_loss_chm /= len(val_dataloader)
# cum_loss_gen /= len(val_dataloader)
logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_total = {:.4f}'.format(
datetime.timedelta(seconds=int(duration)), cum_loss_tot
))
# logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format(
# datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot
# ))
loss_dict = {
'stc': cum_loss_stc,
'stm': cum_loss_stm,
'vcc': cum_loss_vcc,
'vcm': cum_loss_vcm,
# 'vhc': cum_loss_vhc,
# 'vhm': cum_loss_vhm,
# 'chc': cum_loss_chc,
# 'chm': cum_loss_chm,
'mlm': cum_loss_mlm,
# 'gen': cum_loss_gen,
'tot': cum_loss_tot
}
return loss_dict


def pre_train(
model,
model_without_ddp,
train_dataloaders,
val_dataloaders,
optimizer,
global_step,
webvid_step,
cc3m_step,
scheduler,
scaler,
start_epoch,
config
):
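    """Top-level pre-training loop.

    Runs `run_epoch` for each epoch, validates on every dataloader in
    `val_dataloaders` (skipped when config['debugging'] is set), saves a
    checkpoint after each epoch, and keeps `ckpt_best.pth` according to
    `config.stop_key` (or simply the latest epoch when stop_key is None).
    """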
if is_main_process() and config['wandb_enabled']:
run = setup_wandb(config)
setup_seed(config['seed'] + get_rank())
device = torch.device('cuda:{}'.format(config['gpu']))
if is_main_process() and config['wandb_enabled']:
wandb.watch(model)
best = float('inf')
best_epoch = 0
logger.info('[INFO] Start training...')
start_time_all = time()
for epoch in range(start_epoch, config['epochs']):
if not config['evaluate']:
start_time_epoch = time()
global_step, webvid_step, cc3m_step = run_epoch(
model,
train_dataloaders,
optimizer,
epoch,
global_step,
webvid_step,
cc3m_step,
device,
scheduler,
scaler,
config
)
end_time_epoch = time()
epoch_time = end_time_epoch - start_time_epoch
epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
logger.info(f'[INFO] Epoch took {epoch_time_str}')
if not config['debugging']:
with torch.cuda.amp.autocast(enabled=config['fp16']):
# # TODO
# eval_res = {}
# for val_name, val_loader in val_dataloaders_dict.items():
# res = evaluation_wrapper(
# model_without_ddp, val_loader, tokenizer, device, config, prefix=val_name
# )
# eval_res.update(res)
val_res = {}
for medium in val_dataloaders:
res = eval(
model,
val_dataloaders[medium],
device,
epoch,
config
)
val_res[medium] = res
if is_main_process():
# Average across all datasets
avg_val_res = average_dicts(val_res)
# log to wandb
if config.wandb_enabled:
for medium in val_res:
log_dict_val = {}
# log_dict_val[f'val/{medium}/step'] = epoch
for l in val_res[medium]:
log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l]
wandb.log(log_dict_val)
# for p, v in eval_res.items():
# log_dict_to_wandb(v, step=global_step, prefix=p)
if config.stop_key is not None and config.stop_key in avg_val_res:
cur_best = avg_val_res[config.stop_key]
else: # stop_key = None
cur_best = best - 1 # save the last as the best
# Don't save vit weights as they are frozen
state_dict = model_without_ddp.state_dict()
state_dict = {k:v for k,v in state_dict.items() if 'visual_encoder' not in k}
save_obj = {
"model": state_dict,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"scaler": scaler.state_dict(),
"config": config,
"epoch": epoch,
"global_step": global_step,
}
torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth"))
if not config.evaluate and cur_best < best:
torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth"))
# eval_file = "eval_res_best.json"
# eval_res.to_json(os.path.join(config.log_dir, eval_file))
best = cur_best
if config.evaluate:
break
if config['distributed']:
dist.barrier()
total_time = time() - start_time_all
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f'[INFO] Training took {total_time_str}')
if is_main_process() and config['wandb_enabled']:
run.finish()