initial commit
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions
tasks/pre_train.py (new file, 413 lines)

@@ -0,0 +1,413 @@
import os
import datetime
import wandb
import torch
import pandas as pd
from time import time

import torch.distributed as dist
from torch.distributed import ReduceOp

from torch.nn.utils.clip_grad import clip_grad_value_
from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts
from datasets.utils import get_datasets_media
from datasets.dataloader import MetaLoader
from utils.dist import is_main_process, get_rank, get_world_size
from utils.logger import setup_wandb, log_dict_to_wandb
from .retrieval_utils import evaluation_wrapper
import glog as logger


def run_epoch(
    model,
    train_dataloaders,
    optimizer,
    epoch,
    global_step,
    webvid_step,
    cc3m_step,
    device,
    scheduler,
    scaler,
    config
):
    model.train()
    media_types = list(train_dataloaders.keys())

    log_freq = config['log_freq']
    # metric_logger = MetricLogger(delimiter=' ')
    # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}'))
    # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}"))

    loss_names = ['loss_' + k for k in config['loss_dict'].keys()]
    # for l in loss_names:
    #     for m in media_types:
    #         metric_logger.add_meter(
    #             f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}")
    #         )

    # header = '{} | Epoch = {}'.format(config['stage'], epoch)

    model_without_ddp = model
    if config['distributed']:
        model_without_ddp = model.module
        for k in train_dataloaders:
            train_dataloaders[k].sampler.set_epoch(epoch)

    train_dataloader = MetaLoader(name2loader=train_dataloaders)

    log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Losses] mlm (x{}) = {:.4f} | vcc (x{}) = {:.4f} | vcm (x{}) = {:.4f} | stc (x{}) = {:.4f} | stm (x{}) = {:.4f}\n'
    log_text_template += '[Other] lr = {:.4f} | temp = {:.4f} | eta = {}\n'

    # iterator = metric_logger.log_every(train_dataloader, log_freq, header)
    local_step = 0
    for media_type, (vis, caption, neg_vis) in train_dataloader:
        start = time()
        # loss_dict = {}
        vis = vis.to(device)
        neg_vis = neg_vis.to(device)
        # idx = idx.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            loss_dict = model(vis, caption, neg_vis, media_type)
            # loss_dict.update(losses)
            loss = sum(loss_dict.values())
            loss_accum_grad = loss / config.accum_grad_every

        scaler.scale(loss_accum_grad).backward()

        # Perform gradient clipping: unscale --> clip
        if config['clip_grad_value'] > 0:
            # scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), config.clip_grad_value)

        if local_step % config.accum_grad_every == 0:
            scaler.step(optimizer)
            scaler.update()
            # scheduler.step(epoch, global_step)
            scheduler.step()
            optimizer.zero_grad()

        time_iter = time() - start
        eta = (len(train_dataloader) - local_step - 1) * time_iter
        eta = str(datetime.timedelta(seconds=eta))
        # log
        log_dict_webvid = {}
        log_dict_cc3m = {}
        log_dict_rest = {}
        for loss_name in loss_names:
            value = loss_dict[loss_name]
            value = value if isinstance(value, float) else value.item()
            # metric_logger.update(**{f"{media_type}/{loss_name}": value})
            if media_type == "cc3m":
                log_dict_cc3m[f"train/{media_type}/{loss_name}"] = value
            else:
                log_dict_webvid[f"train/{media_type}/{loss_name}"] = value

        # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        # metric_logger.update(temperature=model_without_ddp.temp.item())
        log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"]
        log_dict_rest['train/other/temperature'] = model_without_ddp.temp.item()

        if is_main_process() and global_step % log_freq == 0 and local_step % config.accum_grad_every == 0:
            log_dict_rest['train/other/step'] = global_step
            if media_type == 'cc3m':
                log_dict_cc3m['train/cc3m/step'] = cc3m_step

                log_text = log_text_template.format(
                    epoch, config.epochs - 1, local_step, len(train_dataloader), media_type,
                    config.loss_dict['mlm'], log_dict_cc3m['train/cc3m/loss_mlm'],
                    config.loss_dict['vcc'], log_dict_cc3m['train/cc3m/loss_vcc'],
                    config.loss_dict['vcm'], log_dict_cc3m['train/cc3m/loss_vcm'],
                    config.loss_dict['stc'], log_dict_cc3m['train/cc3m/loss_stc'],
                    config.loss_dict['stm'], log_dict_cc3m['train/cc3m/loss_stm'],
                    log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta
                )
                logger.info(log_text)

                if config['wandb_enabled']:
                    wandb.log(log_dict_rest)
                    wandb.log(log_dict_cc3m)
                # log_text_template = '[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
                # log_text_template += '[losses: mlm = {:.4f} | vcc = {:4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f}]\n'
                # log_text_template += '[Other: lr = {:.4f} | temp = {:4f}]\n'

            else:
                log_dict_webvid['train/webvid/step'] = webvid_step
                log_text = log_text_template.format(
                    epoch, config.epochs - 1, local_step, len(train_dataloader), media_type,
                    config.loss_dict['mlm'], log_dict_webvid['train/webvid/loss_mlm'],
                    config.loss_dict['vcc'], log_dict_webvid['train/webvid/loss_vcc'],
                    config.loss_dict['vcm'], log_dict_webvid['train/webvid/loss_vcm'],
                    config.loss_dict['stc'], log_dict_webvid['train/webvid/loss_stc'],
                    config.loss_dict['stm'], log_dict_webvid['train/webvid/loss_stm'],
                    log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], eta
                )
                logger.info(log_text)

                if config['wandb_enabled']:
                    wandb.log(log_dict_rest)
                    wandb.log(log_dict_webvid)

        if media_type == "cc3m":
            cc3m_step += 1
        else:
            webvid_step += 1
        global_step += 1
        local_step += 1
    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    # logger.info(f"Averaged stats: {metric_logger.global_avg()}")

    return global_step, webvid_step, cc3m_step


def eval(model, val_dataloader, device, epoch, config):

    model.eval()

    log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Losses] mlm = {:.4f} | vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} \n'

    # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
    # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'

    cum_loss_stc = 0
    cum_loss_stm = 0
    cum_loss_vcc = 0
    cum_loss_vcm = 0
    cum_loss_mlm = 0
    cum_loss_tot = 0
    val_step = 0

    # val_dataloader = MetaLoader(name2loader=val_dataloaders)
    media_type = val_dataloader.dataset.medium

    if is_main_process():
        start_time = time()

    # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
    for vis, caption, neg_vis in val_dataloader:
        # for vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
        vis = vis.to(device)
        neg_vis = neg_vis.to(device)
        # idx = idx.to(device)

        with torch.cuda.amp.autocast(enabled=config['fp16']):
            with torch.no_grad():
                # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type)
                # loss_dict = model(vis, caption, neg_vis, neg_caption, media_type, file, neg_file)
                loss_dict = model(vis, caption, neg_vis, media_type)

        loss = sum(loss_dict.values())
        loss_stc = loss_dict['loss_stc']
        loss_stm = loss_dict['loss_stm']
        loss_vcc = loss_dict['loss_vcc']
        loss_vcm = loss_dict['loss_vcm']
        loss_mlm = loss_dict['loss_mlm']

        if config['distributed']:
            dist.all_reduce(loss, op=ReduceOp.AVG)
            if config.loss_dict['stc'] != 0:
                dist.all_reduce(loss_stc, op=ReduceOp.AVG)
            if config.loss_dict['stm'] != 0:
                dist.all_reduce(loss_stm, op=ReduceOp.AVG)
            if config.loss_dict['vcc'] != 0:
                dist.all_reduce(loss_vcc, op=ReduceOp.AVG)
            if config.loss_dict['vcm'] != 0:
                dist.all_reduce(loss_vcm, op=ReduceOp.AVG)
            if config.loss_dict['mlm'] != 0:
                dist.all_reduce(loss_mlm, op=ReduceOp.AVG)

        if is_main_process():
            cum_loss_tot += loss.item()
            cum_loss_stc += loss_stc.item()
            cum_loss_stm += loss_stm.item()
            cum_loss_vcc += loss_vcc.item()
            cum_loss_vcm += loss_vcm.item()
            cum_loss_mlm += loss_mlm.item()

            if val_step % config.log_freq == 0:
                log_text = log_text_template.format(
                    epoch, val_step, len(val_dataloader), media_type,
                    loss_mlm, loss_vcc, loss_vcm, loss_stc, loss_stm)
                # log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
                # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
                # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
                # log_text = log_text_template.format(
                #     epoch, val_step, len(val_dataloader), media_type,
                #     loss_vcc, loss_vcm, loss_stc, loss_stm, 0,
                #     loss_vhc, loss_vhm, loss_chc, loss_chm, loss_gen
                # )

                logger.info(log_text)
                # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format(
                #     epoch, val_step, len(val_dataloader), gen_loss, loss
                # ))
        val_step += 1

    if config['distributed']:
        dist.barrier()

    if is_main_process():
        duration = time() - start_time

        cum_loss_tot /= len(val_dataloader)
        cum_loss_stc /= len(val_dataloader)
        cum_loss_stm /= len(val_dataloader)
        cum_loss_vcc /= len(val_dataloader)
        cum_loss_vcm /= len(val_dataloader)
        cum_loss_mlm /= len(val_dataloader)

        # cum_loss_vhc /= len(val_dataloader)
        # cum_loss_vhm /= len(val_dataloader)
        # cum_loss_chc /= len(val_dataloader)
        # cum_loss_chm /= len(val_dataloader)
        # cum_loss_gen /= len(val_dataloader)
        logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_total = {:.4f}'.format(
            datetime.timedelta(seconds=int(duration)), cum_loss_tot
        ))

        # logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format(
        #     datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot
        # ))

    loss_dict = {
        'stc': cum_loss_stc,
        'stm': cum_loss_stm,
        'vcc': cum_loss_vcc,
        'vcm': cum_loss_vcm,
        # 'vhc': cum_loss_vhc,
        # 'vhm': cum_loss_vhm,
        # 'chc': cum_loss_chc,
        # 'chm': cum_loss_chm,
        'mlm': cum_loss_mlm,
        # 'gen': cum_loss_gen,
        'tot': cum_loss_tot
    }
    return loss_dict


def pre_train(
    model,
    model_without_ddp,
    train_dataloaders,
    val_dataloaders,
    optimizer,
    global_step,
    webvid_step,
    cc3m_step,
    scheduler,
    scaler,
    start_epoch,
    config
):
    if is_main_process() and config['wandb_enabled']:
        run = setup_wandb(config)
    setup_seed(config['seed'] + get_rank())
    device = torch.device('cuda:{}'.format(config['gpu']))

    if is_main_process() and config['wandb_enabled']:
        wandb.watch(model)

    best = float('inf')
    best_epoch = 0

    logger.info('[INFO] Start training...')
    start_time_all = time()
    for epoch in range(start_epoch, config['epochs']):
        if not config['evaluate']:
            start_time_epoch = time()
            global_step, webvid_step, cc3m_step = run_epoch(
                model,
                train_dataloaders,
                optimizer,
                epoch,
                global_step,
                webvid_step,
                cc3m_step,
                device,
                scheduler,
                scaler,
                config
            )
            end_time_epoch = time()
            epoch_time = end_time_epoch - start_time_epoch
            epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
            logger.info(f'[INFO] Epoch took {epoch_time_str}')

        if not config['debugging']:
            with torch.cuda.amp.autocast(enabled=config['fp16']):
                # # TODO
                # eval_res = {}
                # for val_name, val_loader in val_dataloaders_dict.items():
                #     res = evaluation_wrapper(
                #         model_without_ddp, val_loader, tokenizer, device, config, prefix=val_name
                #     )
                #     eval_res.update(res)
                val_res = {}

                for medium in val_dataloaders:
                    res = eval(
                        model,
                        val_dataloaders[medium],
                        device,
                        epoch,
                        config
                    )
                    val_res[medium] = res

            if is_main_process():
                # Average across all datasets
                avg_val_res = average_dicts(val_res)
                # log to wandb
                if config.wandb_enabled:
                    for medium in val_res:
                        log_dict_val = {}
                        # log_dict_val[f'val/{medium}/step'] = epoch
                        for l in val_res[medium]:
                            log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l]
                        wandb.log(log_dict_val)
                # for p, v in eval_res.items():
                #     log_dict_to_wandb(v, step=global_step, prefix=p)
                if config.stop_key is not None and config.stop_key in avg_val_res:
                    cur_best = avg_val_res[config.stop_key]
                else:  # stop_key = None
                    cur_best = best - 1  # save the last as the best

                # Don't save vit weights as they are frozen
                state_dict = model_without_ddp.state_dict()
                state_dict = {k: v for k, v in state_dict.items() if 'visual_encoder' not in k}

                save_obj = {
                    "model": state_dict,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth"))

                if not config.evaluate and cur_best < best:
                    torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth"))
                    # eval_file = "eval_res_best.json"
                    # eval_res.to_json(os.path.join(config.log_dir, eval_file))
                    best = cur_best

        if config.evaluate:
            break
        if config['distributed']:
            dist.barrier()

    total_time = time() - start_time_all
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'[INFO] Training took {total_time_str}')

    if is_main_process() and config['wandb_enabled']:
        run.finish()
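Note on the training loop above: run_epoch follows the standard torch.cuda.amp recipe with gradient accumulation. The loss is divided by accum_grad_every, scaled gradients accumulate across iterations, and the optimizer, scaler, and scheduler only step on accumulation boundaries. Below is a minimal, self-contained sketch of that generic pattern for reference; the names train_steps, loader, and accum_every are illustrative only and not part of this repository.

import torch

def train_steps(model, loader, optimizer, scheduler, accum_every=4, fp16=True):
    # Minimal AMP + gradient-accumulation loop (illustrative names, see note above).
    scaler = torch.cuda.amp.GradScaler(enabled=fp16)
    model.train()
    for step, (inputs, targets) in enumerate(loader):
        with torch.cuda.amp.autocast(enabled=fp16):
            loss = model(inputs, targets)      # assumes the model returns a scalar loss
            loss = loss / accum_every          # so accumulated gradients average out
        scaler.scale(loss).backward()          # gradients accumulate until the next step
        if (step + 1) % accum_every == 0:
            scaler.step(optimizer)             # unscales internally, skips step on inf/NaN
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

Dividing the loss by accum_every keeps the effective gradient equal to the average over the accumulated mini-batches, so the learning-rate schedule behaves as if a single large batch had been used.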
tasks/retrieval_utils.py (new file, 435 lines)

@@ -0,0 +1,435 @@
import datetime
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange

from models.criteria import get_sim
from utils.basic import MetricLogger
from utils.dist import get_rank, get_world_size

logger = logging.getLogger(__name__)


def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_atts = []

    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)

        text_feat = model.encode_text(text_input)[0]
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)

    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    return text_feats, text_atts


def extract_vision_feats(data_loader, model, device, config):
    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    media_type = data_loader.dataset.medium
    for vis, _ in iterator:
        vis = vis.to(device, non_blocking=True)
        vis_feat, pooled_vis_feat = model.get_vis_enc_for_eval(vis, media_type)
        # if config.evaluation.eval_frame_ensemble == "concat":  # default
        #     image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
        vis_feat = vis_feat.unsqueeze(1)  # (bsz, 1, l, d)
        # else:
        #     assert config.video_input.num_frames == 1, "only support single-frame"
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        if not config.eval_offload:
            image_feats_all.append(vis_feat.cpu())
            pooled_image_feats_all.append(pooled_vis_feat.cpu())
        else:
            image_feats_all.append(vis_feat)
            pooled_image_feats_all.append(pooled_vis_feat)

    image_feats_all = torch.cat(image_feats_all, dim=0)
    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return image_feats_all, pooled_image_feats_all


@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
    with torch.cuda.amp.autocast(enabled=config.fp16):
        i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
            model, data_loader, tokenizer, device, config
        )
    score_pairs = [
        (prefix + "/", i2t_x, t2i_x),
        (prefix + "_emb/", i2t_emb, t2i_emb),
    ]
    res = dict()
    for name, i2t, t2i in score_pairs:
        if i2t is not None:
            txt2img_ids = data_loader.dataset.txt2vis
            img2txt_ids = data_loader.dataset.vis2txt
            res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
    return res


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"
    dtype = torch.half if config.fp16 else torch.float
    media_type = data_loader.dataset.medium
    logger.info(f"Start evaluation for {media_type}")

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features in each GPU
    texts = data_loader.dataset.text
    max_txt_l = config.max_cap_len

    text_feats, text_atts = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )  # (bsz, Lt, d), (bsz, Lt)

    image_feats, pooled_image_feats = extract_vision_feats(
        data_loader, model, device, config
    )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
    logger.info("Finished feature extraction")
    logger.info("Computing ITC scores [dot-product]")
    _pooled_image_feats = (
        pooled_image_feats.to(device, non_blocking=True)
        if config.eval_offload
        else pooled_image_feats
    )
    i2t_scores, t2i_scores = get_sim(
        model.vis_proj(_pooled_image_feats), model.cap_proj(text_feats[:, 0])
    )
    logger.info("Computing ITC scores [dot-product], done!")

    num_images = len(data_loader.dataset.vis)
    i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    # computes only part of the scores at each GPU, gather at the end
    logger.info("Rerank dual-encoder results with cross-encoder...")
    num_tasks = get_world_size()
    rank = get_rank()
    # only uses the part associated with the raw eval set
    # compute image2text #
    step = num_images // num_tasks + 1
    start = rank * step
    end = min(num_images, start + step)

    text_encoder = model.get_expert_encoder('vis_cap_grounding')

    iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
    logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

    # generate score for each clip, and aggregate all clip scores for a video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not False else image_feats[0].shape[1]
    # )

    # logger.info(
    #     f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={'concat'}"
    # )
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            # if config.deep_fusion:
            #     encoder_output = [
            #         feat[start + i, clip_idx].to(device, non_blocking=True)
            #         for feat in image_feats
            #     ]
            # else:
            encoder_output = (
                image_feats[start + i, clip_idx].to(device, non_blocking=True)
                if config.eval_offload
                else image_feats[start + i, clip_idx]
            )  # (#frm*Li, d)

            """ original
            encoder_output = encoder_output.repeat(k, 1, 1)  # (k=128, #frm*Li, d)
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[topk_idx],
                attention_mask=text_atts[topk_idx],
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )

            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []

            # if config.deep_fusion:
            #     encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
            #     encoder_att = [
            #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
            #             device, non_blocking=True
            #         )
            #         for feat in encoder_output
            #     ]
            # else:
            encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
            encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                device, non_blocking=True
            )

            # note: assumes k is a multiple of bs, so each text chunk matches the
            # repeated encoder_output batch size
            for j in range(0, len(topk_idx), bs):
                output = text_encoder(
                    encoder_embeds=text_feats[topk_idx[j : j + bs]],
                    attention_mask=text_atts[topk_idx[j : j + bs]],
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                )
                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)
            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1."
        #         )

        i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

    # compute text2image #
    num_text = len(data_loader.dataset.text)
    t2i_scores_x = torch.full((num_text, len(data_loader.dataset.vis)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    step = num_text // num_tasks + 1
    start = rank * step
    end = min(num_text, start + step)

    iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
    logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
    # generate score for each clip, and aggregate all clip scores for a video
    n_clip_per_video = 1
    # (
    #     image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
    # )
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.eval_k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)
        # topk_idx =
        clip_scores = []
        for clip_idx in range(n_clip_per_video):

            """ old
            encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \
                if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx]
            encoder_att = torch.ones(
                encoder_output.size()[:-1], dtype=torch.long
            ).to(device, non_blocking=True)
            output = text_encoder(
                encoder_embeds=text_feats[start+i].repeat(k, 1, 1),
                attention_mask=text_atts[start+i].repeat(k, 1),
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                return_dict=True,
                mode="fusion"
            )

            itm_embeds = output.last_hidden_state[:, 0]
            """

            # new
            bs = 32
            # bs = config.batch_size_test.video
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):

                # if config.deep_fusion:
                #     encoder_output = [
                #         feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
                #         for feat in image_feats
                #     ]
                #     encoder_att = [
                #         torch.ones(feat.size()[:-1], dtype=torch.long).to(
                #             device, non_blocking=True
                #         )
                #         for feat in encoder_output
                #     ]
                # else:
                encoder_output = (
                    image_feats[topk_idx[j : j + bs], clip_idx].to(
                        device, non_blocking=True
                    )
                    if config.eval_offload
                    else image_feats[topk_idx[j : j + bs], clip_idx]
                )
                encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                    device, non_blocking=True
                )

                repeat_n = (
                    encoder_output.shape[0]
                    # if not config.deep_fusion
                    # else encoder_output[0].shape[0]
                )
                output = text_encoder(
                    encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
                    attention_mask=text_atts[start + i].repeat(repeat_n, 1),
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                    # mode="fusion",
                )

                batch_itm_embeds = output.last_hidden_state[:, 0]
                itm_embeds.append(batch_itm_embeds)

            itm_embeds = torch.cat(itm_embeds, dim=0)
            # end new

            score = model.vcm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        # if len(clip_scores) == 1:
        score = clip_scores[0]
        # else:
        #     assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        #     clip_scores = torch.stack(clip_scores)  # (#clips, k)
        #     if config.evaluation.eval_frame_ensemble == "mean":
        #         score = clip_scores.mean(0)
        #     elif config.evaluation.eval_frame_ensemble == "max":
        #         score = clip_scores.max(0)[0]
        #     elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
        #         score = torch.logsumexp(clip_scores, dim=0)
        #     else:
        #         raise ValueError(
        #             "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1."
        #         )

        t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

    if config.distributed:
        # gather across GPUs
        dist.barrier()
        dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
        dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    return (
        i2t_scores_x.cpu().numpy(),
        t2i_scores_x.cpu().numpy(),
        i2t_scores.cpu().numpy(),
        i2t_scores.T.cpu().numpy(),
    )


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    # Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # Score
        gt_txt_ids = img2txt[index]
        if isinstance(gt_txt_ids, int):
            ranks[index] = np.where(inds == gt_txt_ids)[0][0]
        else:
            rank = 1e20
            for i in gt_txt_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text->Images
    ranks = np.zeros(scores_t2i.shape[0])

    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        gt_img_ids = txt2img[index]
        if isinstance(gt_img_ids, int):
            ranks[index] = np.where(inds == gt_img_ids)[0][0]
        else:  # list, used in the case each caption has multiple GT images
            # Score
            rank = 1e20
            for i in gt_img_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "vis_r1": ir1,
        "vis_r5": ir5,
        "vis_r10": ir10,
        "vis_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    eval_result = {k: round(v, 2) for k, v in eval_result.items()}
    return eval_result
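Note on itm_eval above: it converts the image-to-text and text-to-image score matrices into Recall@1/5/10 by ranking the candidates for each query and recording where the ground-truth item lands. A toy, self-contained example of the same computation follows; the 3x3 score matrix and identity ground truth are made up for illustration and are not data from this repository.

import numpy as np

# 3 images x 3 texts; ground truth is the diagonal (image i matches text i).
scores_i2t = np.array([[0.9, 0.1, 0.0],
                       [0.2, 0.3, 0.8],
                       [0.1, 0.7, 0.4]])
img2txt = {0: 0, 1: 1, 2: 2}

ranks = np.zeros(scores_i2t.shape[0])
for index, score in enumerate(scores_i2t):
    inds = np.argsort(score)[::-1]                         # candidates sorted by descending score
    ranks[index] = np.where(inds == img2txt[index])[0][0]  # rank of the ground-truth text

r1 = 100.0 * (ranks < 1).mean()   # ranks are [0, 1, 1] here, so R@1 is about 33.3
r5 = 100.0 * (ranks < 5).mean()   # and R@5 is 100.0
print(r1, r5)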
tasks/stage_2.py (new file, 373 lines)

@@ -0,0 +1,373 @@
import os
import datetime
import wandb
import torch
import pandas as pd
from time import time

import torch.distributed as dist
from torch.distributed import ReduceOp

from torch.nn.utils.clip_grad import clip_grad_value_
from utils.basic import MetricLogger, SmoothedValue, setup_seed, average_dicts
from datasets.utils import get_datasets_media
from datasets.dataloader import MetaLoader
from utils.dist import is_main_process, get_rank, get_world_size
from utils.logger import setup_wandb, log_dict_to_wandb
from .retrieval_utils import evaluation_wrapper
import glog as logger


def run_epoch(
    model,
    train_dataloaders,
    optimizer,
    epoch,
    global_step,
    device,
    scheduler,
    scaler,
    config
):
    model.train()
    media_types = list(train_dataloaders.keys())

    log_freq = config['log_freq']
    # metric_logger = MetricLogger(delimiter=' ')
    # metric_logger.add_meter('lr', SmoothedValue(window=log_freq, fmt='{value:.6f}'))
    # metric_logger.add_meter("temperature", SmoothedValue(window=log_freq, fmt="{value:.4f}"))

    loss_names = ['loss_' + k for k in config['loss_dict'].keys()]
    # for l in loss_names:
    #     for m in media_types:
    #         metric_logger.add_meter(
    #             f'{m}/{l}', SmoothedValue(window=log_freq, fmt="{value:.4f}")
    #         )

    # header = '{} | Epoch = {}'.format(config['stage'], epoch)

    model_without_ddp = model
    if config['distributed']:
        model_without_ddp = model.module
        for k in train_dataloaders:
            train_dataloaders[k].sampler.set_epoch(epoch)

    train_dataloader = MetaLoader(name2loader=train_dataloaders)

    log_text_template = '\n' + '-' * 25 + '\n[Epoch {}/{}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Losses] gen = {:.4f} | vhc = {:.4f} | vhm = {:.4f} | stc = {:.4f} | stm = {:.4f}\n'
    log_text_template += '[Other] lr = {:.4f} | temp = {:.4f} | iter_time = {:.2f} | eta = {}\n'

    # iterator = metric_logger.log_every(train_dataloader, log_freq, header)
    local_step = 0
    for media_type, (vis, caption, history, answer) in train_dataloader:
        # for media_type, (vis, caption, neg_vis, neg_caption, idx) in train_dataloader:

        start = time()
        # loss_dict = {}
        vis = vis.to(device)
        # neg_vis = neg_vis.to(device)
        # idx = idx.to(device)

        with torch.cuda.amp.autocast(enabled=config.fp16):
            loss_dict = model(vis, caption, history, answer, media_type)
            loss = sum(loss_dict.values())
            loss_accum_grad = loss / config.accum_grad_every

        scaler.scale(loss_accum_grad).backward()

        # Perform gradient clipping: unscale --> clip
        if config['clip_grad_value'] > 0:
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), config.clip_grad_value)

        if local_step % config.accum_grad_every == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        time_iter = time() - start
        eta = (len(train_dataloader) - local_step - 1) * time_iter
        eta = str(datetime.timedelta(seconds=eta))
        # log
        log_dict = {}
        log_dict_rest = {}
        for loss_name in loss_names:
            value = loss_dict[loss_name]
            value = value if isinstance(value, float) else value.item()
            log_dict[f"train/{media_type}/{loss_name}"] = value

        # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        # metric_logger.update(temperature=model_without_ddp.temp.item())
        log_dict_rest['train/other/lr'] = optimizer.param_groups[0]["lr"]
        log_dict_rest['train/other/temperature'] = model_without_ddp.temp.item()

        if is_main_process() and global_step % log_freq == 0 and local_step % config.accum_grad_every == 0:
            # log_dict['train/webvid/step'] = webvid_step
            log_text = log_text_template.format(
                epoch, config.epochs - 1, local_step, len(train_dataloader), media_type,
                log_dict['train/champagne/loss_gen'], log_dict['train/champagne/loss_vhc'], log_dict['train/champagne/loss_vhm'],
                log_dict['train/champagne/loss_stc'], log_dict['train/champagne/loss_stm'],
                log_dict_rest['train/other/lr'], log_dict_rest['train/other/temperature'], time_iter, eta
            )
            logger.info(log_text)
            log_dict_rest['train/other/step'] = global_step
            log_dict['train/champagne/step'] = global_step

            if config['wandb_enabled']:
                wandb.log(log_dict)
                wandb.log(log_dict_rest)

        global_step += 1
        local_step += 1
    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    # logger.info(f"Averaged stats: {metric_logger.global_avg()}")

    return global_step


def eval(model, val_dataloader, device, epoch, config):

    model.eval()

    log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
    log_text_template += '[Losses] gen = {:.4f} | vhc = {:.4f} | vhm = {:.4f} | stc = {:.4f} | stm = {:.4f} \n'

    # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
    # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'

    cum_loss_stc = 0
    cum_loss_stm = 0
    cum_loss_vhc = 0
    cum_loss_vhm = 0
    cum_loss_gen = 0
    cum_loss_tot = 0
    val_step = 0

    # val_dataloader = MetaLoader(name2loader=val_dataloaders)
    media_type = val_dataloader.dataset.medium

    if is_main_process():
        start_time = time()

    # for vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
    for vis, caption, history, answer in val_dataloader:
        # for vis, cap_ids, hist_ids, label_ids, enc_dec_input_ids, idx, _ in val_dataloader:
        vis = vis.to(device)
        # neg_vis = neg_vis.to(device)
        # idx = idx.to(device)

        with torch.cuda.amp.autocast(enabled=config['fp16']):
            with torch.no_grad():
                # loss_dict, _ = model(vis, cap_ids, hist_ids, ques_ids, label_ids, enc_dec_input_ids, media_type)
                loss_dict = model(vis, caption, history, answer, media_type)

        loss = sum(loss_dict.values())
        loss_stc = loss_dict['loss_stc']
        loss_stm = loss_dict['loss_stm']
        loss_vhc = loss_dict['loss_vhc']
        loss_vhm = loss_dict['loss_vhm']
        loss_gen = loss_dict['loss_gen']

        if config['distributed']:
            dist.all_reduce(loss, op=ReduceOp.AVG)
            if config.loss_dict['stc'] != 0:
                dist.all_reduce(loss_stc, op=ReduceOp.AVG)
            if config.loss_dict['stm'] != 0:
                dist.all_reduce(loss_stm, op=ReduceOp.AVG)
            if config.loss_dict['vhc'] != 0:
                dist.all_reduce(loss_vhc, op=ReduceOp.AVG)
            if config.loss_dict['vhm'] != 0:
                dist.all_reduce(loss_vhm, op=ReduceOp.AVG)
            if config.loss_dict['gen'] != 0:
                dist.all_reduce(loss_gen, op=ReduceOp.AVG)

        if is_main_process():
            cum_loss_tot += loss.item()
            cum_loss_stc += loss_stc.item()
            cum_loss_stm += loss_stm.item()
            cum_loss_vhc += loss_vhc.item()
            cum_loss_vhm += loss_vhm.item()
            cum_loss_gen += loss_gen.item()

            if val_step % config.log_freq == 0:
                log_text = log_text_template.format(
                    epoch, val_step, len(val_dataloader), media_type,
                    loss_gen, loss_vhc, loss_vhm, loss_stc, loss_stm)
                # log_text_template = '\n' + '-' * 25 + '\n[Val Epoch {}][Iter. {}/{}][Media-type {}]\n'
                # log_text_template += '[Losses] vcc = {:.4f} | vcm = {:.4f} | stc = {:.4f} | stm = {:.4f} | mlm = {:.4f} \n'
                # log_text_template += '[Losses] vhc = {:.4f} | vhm = {:.4f} | chc = {:.4f} | chm = {:.4f} | gen = {:.4f} \n'
                # log_text = log_text_template.format(
                #     epoch, val_step, len(val_dataloader), media_type,
                #     loss_vcc, loss_vcm, loss_stc, loss_stm, 0,
                #     loss_vhc, loss_vhm, loss_chc, loss_chm, loss_gen
                # )

                logger.info(log_text)
                # logger.info('[INFO] [Eval. Epoch {}][Iter. {}/{}][Losses] gen = {:.4f} | total = {:.4f}'.format(
                #     epoch, val_step, len(val_dataloader), gen_loss, loss
                # ))
        val_step += 1

    if config['distributed']:
        dist.barrier()

    if is_main_process():
        duration = time() - start_time

        cum_loss_tot /= len(val_dataloader)
        cum_loss_stc /= len(val_dataloader)
        cum_loss_stm /= len(val_dataloader)
        cum_loss_vhc /= len(val_dataloader)
        cum_loss_vhm /= len(val_dataloader)
        cum_loss_gen /= len(val_dataloader)

        # cum_loss_chc /= len(val_dataloader)
        # cum_loss_chm /= len(val_dataloader)
        logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_total = {:.4f}'.format(
            datetime.timedelta(seconds=int(duration)), cum_loss_tot
        ))

        # logger.info('\n' + '-' * 25 + '\n' + 'Eval. took {}\n[Losses] cum_gen = {:.4f} | cum_total = {:.4f}'.format(
        #     datetime.timedelta(seconds=int(duration)), cum_loss_gen, cum_loss_tot
        # ))

    # switch back to training mode
    model.train()

    loss_dict = {
        'stc': cum_loss_stc,
        'stm': cum_loss_stm,
        'vhc': cum_loss_vhc,
        'vhm': cum_loss_vhm,
        # 'chc': cum_loss_chc,
        # 'chm': cum_loss_chm,
        'gen': cum_loss_gen,
        'tot': cum_loss_tot
    }
    return loss_dict


def train(
    model,
    model_without_ddp,
    train_dataloaders,
    val_dataloaders,
    optimizer,
    global_step,
    scheduler,
    scaler,
    start_epoch,
    config
):
    if is_main_process() and config['wandb_enabled']:
        run = setup_wandb(config)
    setup_seed(config['seed'] + get_rank())
    device = torch.device('cuda:{}'.format(config['gpu']))

    if is_main_process() and config['wandb_enabled']:
        wandb.watch(model)

    best = float('inf')
    best_epoch = 0

    logger.info('[INFO] Start training...')
    start_time_all = time()
    for epoch in range(start_epoch, config['epochs']):
        if not config['evaluate']:
            start_time_epoch = time()
            global_step = run_epoch(
                model,
                train_dataloaders,
                optimizer,
                epoch,
                global_step,
                device,
                scheduler,
                scaler,
                config
            )
            end_time_epoch = time()
            epoch_time = end_time_epoch - start_time_epoch
            epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
            logger.info(f'[INFO] Epoch took {epoch_time_str}')

        if not config['debugging']:
            with torch.cuda.amp.autocast(enabled=config['fp16']):
                val_res = {}

                for medium in val_dataloaders:
                    res = eval(
                        model,
                        val_dataloaders[medium],
                        device,
                        epoch,
                        config
                    )
                    val_res[medium] = res

            if is_main_process():
                # Average across all datasets
                avg_val_res = average_dicts(val_res)
                # log to wandb
                if config.wandb_enabled:
                    for medium in val_res:
                        log_dict_val = {}
                        # log_dict_val[f'val/{medium}/step'] = epoch
                        for l in val_res[medium]:
                            log_dict_val[f'val/{medium}/{l}'] = val_res[medium][l]
                        wandb.log(log_dict_val)
                # for p, v in eval_res.items():
                #     log_dict_to_wandb(v, step=global_step, prefix=p)
                if config.stop_key is not None and config.stop_key in avg_val_res:
                    cur_best = avg_val_res[config.stop_key]
                else:  # stop_key = None
                    cur_best = best - 1  # save the last as the best

                # Don't save vit and llm weights as they are frozen
                state_dict = model_without_ddp.state_dict()
                if config.freeze_vit:
                    state_dict = {k: v for k, v in state_dict.items() if 'visual_encoder' not in k}

                if config.freeze_llm:
                    state_dict = {k: v for k, v in state_dict.items() if 'llm' not in k}

                save_obj = {
                    "model": state_dict,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                torch.save(save_obj, os.path.join(config.log_dir, f"ckpt_{epoch:02d}.pth"))

                if not config.evaluate and cur_best < best:
                    torch.save(save_obj, os.path.join(config.log_dir, "ckpt_best.pth"))
                    # eval_file = "eval_res_best.json"
                    # eval_res.to_json(os.path.join(config.log_dir, eval_file))
                    best = cur_best

        if config.evaluate:
            break
        if config['distributed']:
            dist.barrier()

    total_time = time() - start_time_all
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'[INFO] Training took {total_time_str}')

    if is_main_process() and config['wandb_enabled']:
        run.finish()
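Note on checkpointing in train() above: frozen submodules (state_dict keys containing 'visual_encoder' or 'llm') are stripped before saving. A minimal sketch of the matching load side follows, assuming the same key-filtering convention; load_filtered_checkpoint is a hypothetical helper written for illustration, not a function from this repository.

import torch

def load_filtered_checkpoint(model, ckpt_path, device='cpu'):
    # Hypothetical helper: checkpoints saved above omit frozen submodules, so the
    # remaining keys are loaded with strict=False and the frozen parts keep the
    # weights they were initialized with (e.g. pretrained ViT / LLM weights).
    ckpt = torch.load(ckpt_path, map_location=device)
    missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
    # 'missing' is expected to contain only visual_encoder.* / llm.* keys.
    return missing, unexpected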
tasks/stage_3.py (new file, 1051 lines)

File diff suppressed because it is too large.