initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
185
main_stage_3.py
Normal file
185
main_stage_3.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
|
||||
import argparse
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
# from transformers import BartTokenizer
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from utils.init import initialize_from_env
|
||||
# from datasets.pretraining import load_datasets, VideoTextRetDataset
|
||||
# from datasets.utils import get_datasets_media
|
||||
from models.setup import setup_model, setup_data, setup_data_test
|
||||
# from tasks.ft_avsd import ft_avsd, generate
|
||||
from tasks.stage_3 import ft_avsd, generate, generate_nextqa, generate_visdial
|
||||
|
||||
# Command-line interface for the v2dial stage-3 entry point (training / generation).
parser = argparse.ArgumentParser(description='Main script for v2dial')

# Which model/stage to run, as '<model>/<stage>' (split downstream).
parser.add_argument('--model', type=str, default='v2dial/stage_3',
                    help='model name to train or test')

# Execution mode selects between fine-tuning and inference.
parser.add_argument('--mode', type=str, default='generate',
                    help='train, generate or debug')

# Directory holding the checkpoint/logs used for evaluation.
parser.add_argument('--eval_dir', type=str,
                    default='/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/logs/stage_3/v2dial-google_flan-t5-large-finetune_without_stc_stm_only_visdial')

# Weights & Biases configuration.
parser.add_argument('--wandb_mode', type=str, default='online',
                    choices=['online', 'offline', 'disabled', 'run', 'dryrun'])
parser.add_argument('--wandb_project', type=str, default='V2Dial')

# Free-form run tag used to distinguish experiments (and output files).
parser.add_argument('--tag', type=str, default="finetuned_visdial_without_stm_stc",
                    help="Tag to differentiate the models")

# Index window and split selector for chunked generation over the test set.
parser.add_argument('--start_idx_gen', type=int, default=0,
                    help="The start index for generation")
parser.add_argument('--end_idx_gen', type=int, default=10,
                    help="The end index for generation")
parser.add_argument('--gen_subset_num', type=int, default=1,
                    help="The index of the test split for generation")

parser.add_argument('--ssh', action='store_true',
                    help='whether or not we are executing command via ssh. '
                         'If set to True, we will not log.info anything to screen and only redirect them to log file')
|
||||
|
||||
|
||||
def main(gpu, config, args):
    """Per-worker entry point: run fine-tuning or generation on one device.

    Spawned once per GPU by ``mp.spawn`` in distributed mode, or called
    directly with ``gpu=0`` for single-device runs.

    Args:
        gpu: Integer rank of this worker, also used as the CUDA device index.
        config: Run configuration; supports both item (``config['x']``) and
            attribute (``config.x``) access — presumably a custom mapping
            (TODO confirm against ``initialize_from_env``).
        args: Parsed command-line arguments (tag / generation indices used here).
    """
    config['gpu'] = gpu
    if config['distributed']:
        # Each spawned worker joins the NCCL process group and pins itself
        # to its own device before any CUDA work happens.
        dist.init_process_group(
            backend='nccl',
            world_size=config['num_gpus'],
            rank=gpu
        )
        torch.cuda.set_device(gpu)

    device = torch.device(f'cuda:{gpu}')
    if config.use_cpu:
        device = torch.device('cpu')
    config['device'] = device

    # Training needs the train/val loaders before model setup (setup_model
    # consumes config that setup_data may have populated).
    if config['training']:
        train_dataloaders, val_dataloaders = setup_data(config)

    (
        model, model_without_ddp, optimizer, scheduler, scaler, start_epoch, global_step, visdial_step, avsd_step, nextqa_step, config
    ) = setup_model(config)

    if config['training']:
        ft_avsd(
            model,
            model_without_ddp,
            train_dataloaders,
            val_dataloaders,
            optimizer,
            global_step,
            visdial_step,
            avsd_step,
            nextqa_step,
            scheduler,
            scaler,
            start_epoch,
            config
        )
    elif config['generating']:
        test_dataloader = setup_data_test(config)
        # Fix: the original used `if` / `if` / `elif` here, so the visdial
        # check was evaluated even after the avsd branch ran. `media_test`
        # is a single string, so a single exclusive elif-chain is the
        # intended dispatch.
        if config.media_test == 'avsd':
            generate(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
        elif config.media_test == 'visdial':
            generate_visdial(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)
        elif config.media_test == 'nextqa':
            generate_nextqa(model, test_dataloader, args.tag, config, gen_subset_num=args.gen_subset_num)

    if config['distributed']:
        dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    cli = parser.parse_args()

    # Build the run configuration for the chosen '<model>/<stage>' pair.
    model_name, stage = cli.model.split('/')
    config = initialize_from_env(model_name, cli.mode, stage, cli.eval_dir, tag=cli.tag)

    # Derive the boolean mode flags once, up front.
    config['wandb_enabled'] = cli.wandb_mode == 'online'
    config['training'] = cli.mode == 'train'
    config['generating'] = cli.mode == 'generate'
    config['debugging'] = cli.mode == 'debug'

    config['wandb_mode'] = cli.wandb_mode
    config['start_idx_gen'] = cli.start_idx_gen
    config['end_idx_gen'] = cli.end_idx_gen

    # No expert-permutation ablation by default.
    config['expert_permutation'] = None

    # Multi-GPU: spawn one worker process per GPU; otherwise run in-process
    # as rank 0.
    if config['num_gpus'] > 1:
        config['distributed'] = True
        mp.spawn(main, nprocs=config['num_gpus'], args=(config, cli))
    else:
        config['distributed'] = False
        main(0, config, cli)
|
Loading…
Add table
Add a link
Reference in a new issue