# V2Dial/models/setup.py
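"""Model and data setup for V2Dial.

Builds the model (``V2DialBase`` for stage 1 pre-training, ``V2Dial`` for the
later stages) together with its optimizer, scheduler, and gradient scaler, and
creates the stage-specific dataloaders, optionally restoring state from a
checkpoint.
"""
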
import copy
import os.path as osp
import glog as logger
import torch
from torch.utils.data import ConcatDataset
from models.backbones.beit.builder import interpolate_pos_embed_beit
from models.backbones.bert.tokenization_bert import BertTokenizer
from transformers import T5Tokenizer, BartTokenizer, LlamaTokenizer
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
from datasets.dataloader import load_dataloaders
from datasets.pretraining import load_datasets as load_datasets_stage_1
from datasets.visdial_dataset import load_visdial_dataset
from datasets.champagne_dataset import load_champagne_dataset
from datasets.nextqa_dataset import load_nextqa_dataset
from datasets.avsd_dataset import load_avsd_dataset
# from datasets.avsd_dataset_like_mixer import load_avsd_dataset
from processors.blip_processors import Blip2ImageTrainProcessor
from processors.blip_processors import BlipCaptionProcessor, BlipDialogProcessor
from utils.init import set_training_steps
# from models.v2dial import V2Dial, V2DialBase
from models.v2dial import V2DialBase, V2Dial, V2DialNoMoes
# from datasets.avsd_dataset import get_dataset, AVSDDataSet
from torch.utils.data import DataLoader


def setup_model(
    config, has_decoder=False, pretrain=False, find_unused_parameters=True
):
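    """Build the V2Dial model for the current training stage.

    Stage 1 builds ``V2DialBase`` together with its optimizer, scheduler, and
    gradient scaler; later stages build ``V2Dial`` and only create the training
    utilities when ``config['training']`` is set. Weights are loaded from a
    checkpoint when a valid path is configured. The returned tuple additionally
    contains per-dataset step counters in stages 1 and 3.
    """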
    logger.info("Creating model")

    if config['stage'] == 'stage_1':
        config = copy.deepcopy(config)
        # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # model = V2DialBase(config=config, expert_tokenizer=tokenizer)
        model = V2DialBase(config)
        model = model.to(torch.device('cuda'))
        model_without_ddp = model

        optimizer = create_optimizer(config, model)
        scheduler = create_scheduler(config, optimizer)
        scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        if config['distributed']:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[config['gpu']],
                find_unused_parameters=find_unused_parameters,  # `False` for image-only task
            )

        start_epoch = 0
        global_step = 0
        webvid_step = 0
        cc3m_step = 0
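        # Optionally initialise from a pretrained checkpoint; when resuming, the
        # optimizer, scheduler, and scaler states are restored as well.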
        if osp.isfile(config['pretrained_path']):
            logger.info(f"Loading checkpoint from {config['pretrained_path']}")
            checkpoint = torch.load(config['pretrained_path'], map_location="cpu")
            state_dict = checkpoint["model"]

            if config.resume:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                scaler.load_state_dict(checkpoint["scaler"])
                start_epoch = checkpoint["epoch"] + 1
                global_step = checkpoint["global_step"]
            elif not pretrain:  # downstream init from a pretrained checkpoint
                # Interpolate the positional embeddings to the current resolution.
                state_dict = interpolate_pos_embed_beit(state_dict, model_without_ddp)
                # TODO: might need to be updated to match the MoEs
                if not config.evaluate:  # fine-tuning from pretrained weights
                    for key in list(state_dict.keys()):
                        if "bert" in key:
                            encoder_key = key.replace("bert.", "")
                            state_dict[encoder_key] = state_dict[key]
                            if not has_decoder:
                                del state_dict[key]

                        # Init the text decoder from the multimodal encoder (last 6 layers
                        # of model.text_encoder); only needed for generation tasks like VQA.
                        if has_decoder and "text_encoder" in key:
                            if "layer" in key:
                                encoder_keys = key.split(".")
                                layer_num = int(encoder_keys[4])
                                if layer_num < config.model.text_encoder.fusion_layer:
                                    del state_dict[key]
                                    continue
                                else:
                                    decoder_layer_num = layer_num - 9
                                    encoder_keys[4] = str(decoder_layer_num)
                                    encoder_key = ".".join(encoder_keys)
                            else:
                                encoder_key = key
                            decoder_key = encoder_key.replace("text_encoder", "text_decoder")
                            state_dict[decoder_key] = state_dict[key]
                            del state_dict[key]

            msg = model_without_ddp.load_state_dict(state_dict, strict=False)
            logger.info(msg)
            logger.info(f"Loaded checkpoint from {config.pretrained_path}")
        else:
            logger.warning("No pretrained checkpoint provided, training from scratch")
        return (
            model,
            model_without_ddp,
            optimizer,
            scheduler,
            scaler,
            start_epoch,
            global_step,
            webvid_step,
            cc3m_step,
            config
        )
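    # Stages 2 and 3: build the full V2Dial model for fine-tuning / generation.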
    else:
        # config = copy.deepcopy(config)
        # if config['use_original_feats']:
        #     model = AVSDBart(config)
        # else:
        #     # model = V2Dial(config, tokenizer_experts, tokenizer_enc_dec)
        #     if config.use_moes:
        model = V2Dial(config)
        #     else:
        #         model = V2DialNoMoes(config)
        model = model.to(torch.device('cuda'))
        model_without_ddp = model

        # Create the training utilities before any checkpoint is loaded so that
        # their states can be restored when resuming.
        optimizer = None
        scheduler = None
        scaler = None
        if config['training']:
            optimizer = create_optimizer(config, model_without_ddp)
            scheduler = create_scheduler(config, optimizer)
            scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        start_epoch = 0
        global_step = 0
        if config['stage'] == 'stage_3':
            visdial_step = 0
            avsd_step = 0
            nextqa_step = 0

        ckpt_path = config.pretrained_path_resume if config.resume else config.pretrained_path_prev_stage
        if config.generating:
            ckpt_path = config.best_ckpt_path
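        # Load weights from the previous stage (or the resume / best checkpoint), if available.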
        if osp.isfile(ckpt_path):
            logger.info(f"Loading checkpoint from {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            state_dict = checkpoint["model"]

            if config.resume:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                scaler.load_state_dict(checkpoint["scaler"])
                start_epoch = checkpoint["epoch"] + 1
                global_step = checkpoint["global_step"]
                if config['stage'] == 'stage_3':
                    visdial_step = checkpoint['visdial_step']
                    avsd_step = checkpoint['avsd_step']
                    nextqa_step = checkpoint['nextqa_step']

            if config['stage'] in ['stage_2', 'stage_3'] and config.use_moes:
                # Init the history expert weights with the caption expert weights
                p_names = [
                    'moe_layers.{}.norm_hist.weight',
                    'moe_layers.{}.mlp_hist.fc1.weight',
                    'moe_layers.{}.mlp_hist.fc1.bias',
                    'moe_layers.{}.mlp_hist.fc2.weight',
                    'moe_layers.{}.mlp_hist.fc2.bias',
                ]
                for moe_layer_idx in range(config.num_moe_modality_layers):
                    for p_name in p_names:
                        p_hist_name = p_name.format(moe_layer_idx)
                        if p_hist_name not in state_dict:
                            p_cap_name = p_hist_name.replace('hist', 'cap')
                            state_dict[p_hist_name] = state_dict[p_cap_name].clone()

            msg = model_without_ddp.load_state_dict(state_dict, strict=False)
            logger.info(msg)
            logger.info(f"Loaded checkpoint from {ckpt_path}")
        else:
            logger.warning("No pretrained checkpoint provided, training from scratch")
        if config['generating']:
            model.llm.set_input_embeddings(model.text_embedding)

        if config['distributed']:
            static_graph = config.stage != 'stage_1'
            if len(config.media_train) > 0:
                static_graph = False
            model = torch.nn.parallel.DistributedDataParallel(
                model_without_ddp,
                device_ids=[config['gpu']],
                find_unused_parameters=find_unused_parameters,  # `False` for image-only task
                static_graph=static_graph,
            )
        if config['stage'] == 'stage_3':
            return (
                model,
                model_without_ddp,
                optimizer,
                scheduler,
                scaler,
                start_epoch,
                global_step,
                visdial_step,
                avsd_step,
                nextqa_step,
                config
            )

        return (
            model,
            model_without_ddp,
            optimizer,
            scheduler,
            scaler,
            start_epoch,
            global_step,
            config
        )


def setup_data(config):
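    """Create the train/val dataloaders for the current training stage.

    Stage 1 uses the captioning/pretraining datasets (CC3M, WebVid), stage 2 the
    Champagne dialog data, and stage 3 the downstream datasets listed in
    ``config['media_train']`` / ``config['media_val']`` (visdial, avsd, nextqa).
    Returns ``(train_dataloaders, val_dataloaders)`` keyed per dataset
    (``output_dict=True``).
    """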
    logger.info("Creating datasets")

    # define the processors
    vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)

    if config['stage'] == 'stage_1':
        text_processor = BlipCaptionProcessor(max_words=config.max_cap_len)

        if config['debugging']:
            train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')
        else:
            train_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'train')
        val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'val')

        # cc3m_dataset = ConcatDataset([train_datasets['cc3m'], val_datasets['cc3m']])
        # webvid_dataset = ConcatDataset([train_datasets['webvid'], val_datasets['webvid']])
        # train_datasets = [cc3m_dataset, webvid_dataset]
        train_datasets = list(train_datasets.values())
        val_datasets = list(val_datasets.values())

        batch_sizes = [config['batch_size_cc3m'], config['batch_size_webvid']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)
        # val_datasets = load_datasets_stage_1(config, vis_processor, text_processor, 'test')
        # val_dataloader = load_dataloaders(config, val_datasets, 'test', output_dict=True)
    elif config['stage'] == 'stage_2':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50

        train_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'train')]
        val_datasets = [load_champagne_dataset(config, vis_processor, text_processor, 'val')]

        batch_sizes = [config['batch_size_champagne']]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)
    elif config['stage'] == 'stage_3':
        text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50

        train_datasets = []
        val_datasets = []

        for medium in config['media_train']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset
            train_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'train'))

        for medium in config['media_val']:
            if medium == 'visdial':
                load_dataset_fn = load_visdial_dataset
            elif medium == 'avsd':
                load_dataset_fn = load_avsd_dataset
            elif medium == 'nextqa':
                load_dataset_fn = load_nextqa_dataset
            # elif medium == 'champagne':
            #     load_dataset_fn = load_champagne_dataset
            val_datasets.append(load_dataset_fn(config, vis_processor, text_processor, 'val'))

        batch_sizes = [d.batch_size for d in train_datasets]
        num_samples = [len(d) for d in train_datasets]
        config = set_training_steps(config, num_samples, batch_sizes)

        train_dataloaders = load_dataloaders(config, train_datasets, 'train', output_dict=True)
        val_dataloaders = load_dataloaders(config, val_datasets, 'val', output_dict=True)

    return train_dataloaders, val_dataloaders


def setup_data_test(config):
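    """Create the test dataloader for the dataset selected by ``config.media_test``."""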
    vis_processor = Blip2ImageTrainProcessor(image_size=config.image_res)
    text_processor = BlipDialogProcessor(max_words=config.max_text_len)  # max_words = 50

    if config.media_test == 'visdial':
        load_dataset_fn = load_visdial_dataset
    elif config.media_test == 'avsd':
        load_dataset_fn = load_avsd_dataset
    elif config.media_test == 'nextqa':
        load_dataset_fn = load_nextqa_dataset

    test_dataset = load_dataset_fn(config, vis_processor, text_processor, 'test')
    test_dataloader = DataLoader(
        test_dataset, shuffle=False, batch_size=test_dataset.batch_size)
    return test_dataloader

# def setup_data_test(config, args):
#     tokenizer_experts = BertTokenizer.from_pretrained('bert-base-uncased')
#     tokenizer_enc_dec = None
#     if config.enc_dec_family == 'flan_t5':
#         tokenizer_enc_dec = T5Tokenizer.from_pretrained(config.enc_dec_name)
#     elif config.enc_dec_family == 'bart':
#         tokenizer_enc_dec = BartTokenizer.from_pretrained(config.enc_dec_name)
#     if config['tie_embeddings']:
#         tokenizer_experts = tokenizer_enc_dec
#     if config['medium'] == 'avsd':
#         test_dataset = AVSDDataSet(config, 'avsd', tokenizer_experts, tokenizer_enc_dec, 'test')
#     test_dataloader = DataLoader(
#         test_dataset, shuffle=False, batch_size=test_dataset.batch_size, collate_fn=test_dataset.collate_fn)
#     return test_dataloader
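
# Minimal usage sketch (hypothetical driver code; the actual entry point lives in the
# training script, not in this module). Assumes `config` is the parsed experiment
# config with config['stage'] == 'stage_2':
#
#     train_loaders, val_loaders = setup_data(config)
#     (model, model_without_ddp, optimizer, scheduler, scaler,
#      start_epoch, global_step, config) = setup_model(config)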