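"""Model and data setup utilities for V2Dial.

Builds the model, optimizer, scheduler and AMP grad scaler for the configured
training stage (`stage_1`, `stage_2` or `stage_3`), optionally restoring them
from a checkpoint, and creates the train/val/test dataloaders.
"""
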
import copy
import os.path as osp

import glog as logger

import torch
from torch.utils.data import ConcatDataset

from models.backbones.beit.builder import interpolate_pos_embed_beit
from models.backbones.bert.tokenization_bert import BertTokenizer
from transformers import T5Tokenizer, BartTokenizer, LlamaTokenizer
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
from datasets.dataloader import load_dataloaders
from datasets.pretraining import load_datasets as load_datasets_stage_1
from datasets.visdial_dataset import load_visdial_dataset
from datasets.champagne_dataset import load_champagne_dataset
from datasets.nextqa_dataset import load_nextqa_dataset
from datasets.avsd_dataset import load_avsd_dataset
# from datasets.avsd_dataset_like_mixer import load_avsd_dataset

from processors.blip_processors import Blip2ImageTrainProcessor
from processors.blip_processors import BlipCaptionProcessor, BlipDialogProcessor

from utils.init import set_training_steps
# from models.v2dial import V2Dial, V2DialBase
from models.v2dial import V2DialBase, V2Dial, V2DialNoMoes

# from datasets.avsd_dataset import get_dataset, AVSDDataSet
from torch.utils.data import DataLoader


def setup_model(
    config, has_decoder=False, pretrain=False, find_unused_parameters=True
):
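    """Build the model for the configured stage and optionally load a checkpoint.

    Returns the (possibly DDP-wrapped) model, the unwrapped model, the optimizer,
    scheduler and grad scaler, the start epoch, the step counters of the current
    stage, and the (possibly updated) config.
    """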
    logger.info("Creating model")

    if config['stage'] == 'stage_1':
        config = copy.deepcopy(config)

        # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # model = V2DialBase(config=config, expert_tokenizer=tokenizer)
        model = V2DialBase(config)
        model = model.to(torch.device('cuda'))
        model_without_ddp = model
        optimizer = create_optimizer(config, model)
        scheduler = create_scheduler(config, optimizer)
        scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        if config['distributed']:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[config['gpu']],
                find_unused_parameters=find_unused_parameters,  # `False` for image-only task
            )

        start_epoch = 0
        global_step = 0
        webvid_step = 0
        cc3m_step = 0

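        # Optionally restore the model weights (and, when resuming, the optimizer,
        # scheduler and scaler states) from config['pretrained_path']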
        if osp.isfile(config['pretrained_path']):
            logger.info(f"Loading checkpoint from {config['pretrained_path']}")
            checkpoint = torch.load(config['pretrained_path'], map_location="cpu")
            state_dict = checkpoint["model"]

            if config.resume:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                scaler.load_state_dict(checkpoint["scaler"])
                start_epoch = checkpoint["epoch"] + 1
                global_step = checkpoint["global_step"]
            elif not pretrain:  # downstream init from a pretrained checkpoint
                # interpolate positional embeddings
                state_dict = interpolate_pos_embed_beit(state_dict, model_without_ddp)

                # TODO: might need to update to match the MoEs
                if not config.evaluate:  # finetuning from pretrained weights
                    for key in list(state_dict.keys()):
                        if "bert" in key:
                            encoder_key = key.replace("bert.", "")
                            state_dict[encoder_key] = state_dict[key]
                            if not has_decoder:
                                del state_dict[key]

                        # init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
                        # only for generation tasks like VQA
                        if has_decoder and "text_encoder" in key:
                            if "layer" in key:
                                encoder_keys = key.split(".")
                                layer_num = int(encoder_keys[4])
                                if layer_num < config.model.text_encoder.fusion_layer:
                                    del state_dict[key]
                                    continue
                                else:
                                    decoder_layer_num = layer_num - 9
                                    encoder_keys[4] = str(decoder_layer_num)
                                    encoder_key = ".".join(encoder_keys)
                            else:
                                encoder_key = key
                            decoder_key = encoder_key.replace("text_encoder", "text_decoder")
                            state_dict[decoder_key] = state_dict[key]
                            del state_dict[key]

            msg = model_without_ddp.load_state_dict(state_dict, strict=False)
            logger.info(msg)
            logger.info(f"Loaded checkpoint from {config.pretrained_path}")
        else:
            logger.warning("No pretrained checkpoint provided, training from scratch")

        return (
            model,
            model_without_ddp,
            optimizer,
            scheduler,
            scaler,
            start_epoch,
            global_step,
            webvid_step,
            cc3m_step,
            config
        )

    else:
        # config = copy.deepcopy(config)
        # if config['use_original_feats']:
        #     model = AVSDBart(config)
        # else:
        #     # model = V2Dial(config, tokenizer_experts, tokenizer_enc_dec)
        #     if config.use_moes:
        model = V2Dial(config)
        #     else:
        #         model = V2DialNoMoes(config)

        model = model.to(torch.device('cuda'))
        model_without_ddp = model

        optimizer = None
        scheduler = None
        scaler = None

        if config['training']:
            # create these before the checkpoint is loaded below, so their states
            # can be restored when resuming
            optimizer = create_optimizer(config, model_without_ddp)
            scheduler = create_scheduler(config, optimizer)
            scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        start_epoch = 0
        global_step = 0
        if config['stage'] == 'stage_3':
            visdial_step = 0
            avsd_step = 0
            nextqa_step = 0

        ckpt_path = config.pretrained_path_resume if config.resume else config.pretrained_path_prev_stage
        if config.generating:
            ckpt_path = config.best_ckpt_path

        if osp.isfile(ckpt_path):
            logger.info(f"Loading checkpoint from {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            state_dict = checkpoint["model"]

            if config.resume:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                scaler.load_state_dict(checkpoint["scaler"])
                start_epoch = checkpoint["epoch"] + 1
                global_step = checkpoint["global_step"]
                if config['stage'] == 'stage_3':
                    visdial_step = checkpoint['visdial_step']
                    avsd_step = checkpoint['avsd_step']
                    nextqa_step = checkpoint['nextqa_step']

            if config['stage'] in ['stage_2', 'stage_3'] and config.use_moes:
                # Initialize the history expert weights with the caption expert weights
                p_names = [
                    'moe_layers.{}.norm_hist.weight',
                    'moe_layers.{}.mlp_hist.fc1.weight',
                    'moe_layers.{}.mlp_hist.fc1.bias',
                    'moe_layers.{}.mlp_hist.fc2.weight',
                    'moe_layers.{}.mlp_hist.fc2.bias',
                ]

                for moe_layer_idx in range(config.num_moe_modality_layers):
                    for p_name in p_names:
                        p_hist_name = p_name.format(moe_layer_idx)
                        if p_hist_name not in state_dict:
                            p_cap_name = p_hist_name.replace('hist', 'cap')
                            state_dict[p_hist_name] = state_dict[p_cap_name].clone()

            msg = model_without_ddp.load_state_dict(state_dict, strict=False)
            logger.info(msg)

            logger.info(f"Loaded checkpoint from {ckpt_path}")
        else:
            logger.warning("No pretrained checkpoint provided, training from scratch")

        if not config['training'] and config['generating']:
            # generation only: use the model's text embedding table as the LLM input embeddings
            model.llm.set_input_embeddings(model.text_embedding)

        if config['distributed']:
            static_graph = config.stage != 'stage_1'
            if len(config.media_train) > 0:
                static_graph = False

            model = torch.nn.parallel.DistributedDataParallel(
                model_without_ddp,
                device_ids=[config['gpu']],
                find_unused_parameters=find_unused_parameters,  # `False` for image-only task
                static_graph=static_graph
            )

        if config['stage'] == 'stage_3':
            return (
                model,
                model_without_ddp,
                optimizer,
                scheduler,
                scaler,
                start_epoch,
                global_step,
                visdial_step,
                avsd_step,
                nextqa_step,
                config
            )

        return (
            model,
            model_without_ddp,
            optimizer,
            scheduler,
            scaler,
            start_epoch,
            global_step,
            config
        )


def setup_data(config):
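    """Create the train and val dataloaders for the current training stage."""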
    logger.info("Creating datasets")

    # define the processors
    vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)

    if config['stage'] == 'stage_1':
        text_processor = BlipCaptionProcessor(max_words=config.max_cap_len)

        if config['debugging']:
            train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')
        else:
            train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'train')

        val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')

        # cc3m_dataset = ConcatDataset([train_datasets['cc3m'], val_datasets['cc3m']])
        # webvid_dataset = ConcatDataset([train_datasets['webvid'], val_datasets['webvid']])
        # train_datasets = [cc3m_dataset, webvid_dataset]

        train_datasets = list(train_datasets.values())
        val_datasets = list(val_datasets.values())

        batch_sizes = [config['batch_size_cc3m'], config['batch_size_webvid']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

        # val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'test')
        # val_dataloader = load_dataloaders(config, val_datasets, 'test', output_dict=True)

    if config['stage'] == 'stage_2':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50
        train_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'train')]
        val_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'val')]
        batch_sizes = [config['batch_size_champagne']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

    if config['stage'] == 'stage_3':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50
        train_datasets = []
        val_datasets = []
        for medium in config['media_train']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset

            train_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'train'))

        for medium in config['media_val']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset

            val_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'val'))

        batch_sizes = [d.batch_size for d in train_datasets]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

    return train_dataloaders, val_dataloaders


def setup_data_test(config):
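    """Create the test dataloader for the medium given by `config.media_test`."""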
    vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
    text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50

    if config.media_test == 'visdial':
        load_dataset_fn = load_visdial_dataset
    elif config.media_test == 'avsd':
        load_dataset_fn = load_avsd_dataset
    elif config.media_test == 'nextqa':
        load_dataset_fn = load_nextqa_dataset
    test_dataset = load_dataset_fn(config, vis_processor, text_processor, 'test')

    test_dataloader = DataLoader(
        test_dataset, shuffle=False, batch_size=test_dataset.batch_size)

    return test_dataloader


# def setup_data_test(config, args):
#     tokenizer_experts = BertTokenizer.from_pretrained('bert-base-uncased')
#     tokenizer_enc_dec = None
#     if config.enc_dec_family == 'flan_t5':
#         tokenizer_enc_dec = T5Tokenizer.from_pretrained(config.enc_dec_name)
#     elif config.enc_dec_family == 'bart':
#         tokenizer_enc_dec = BartTokenizer.from_pretrained(config.enc_dec_name)
#     if config['tie_embeddings']:
#         tokenizer_experts = tokenizer_enc_dec
#
#     if config['medium'] == 'avsd':
#         test_dataset = AVSDDataSet(config, 'avsd', tokenizer_experts, tokenizer_enc_dec, 'test')
#         test_dataloader = DataLoader(
#             test_dataset, shuffle=False, batch_size=test_dataset.batch_size, collate_fn=test_dataset.collate_fn)
#     return test_dataloader
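

# Usage sketch (assumed driver code, not defined in this module): a training
# entry point elsewhere in the repo would typically call these helpers as
#
#     train_loaders, val_loaders = setup_data(config)
#     model, model_without_ddp, optimizer, scheduler, scaler, \
#         start_epoch, global_step, *step_counters, config = setup_model(config)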