225 lines
6.9 KiB
Python
225 lines
6.9 KiB
Python
from init_utils import (
|
|
load_runner,
|
|
load_dataset,
|
|
set_random_seed,
|
|
set_training_steps,
|
|
initialize_from_env,
|
|
set_log_file,
|
|
copy_file_to_log
|
|
)
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import pyhocon
|
|
import glog as log
|
|
import socket
|
|
import getpass
|
|
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
import torch.nn as nn
|
|
import torch.distributed as dist
|
|
from transformers import BartTokenizer
|
|
|
|
from custom_datasets.avsd import get_dataset as get_avsd_dataset
|
|
from custom_datasets.nextqa import get_dataset as get_nextqa_dataset
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
|
|
parser.add_argument(
|
|
'--model',
|
|
type=str,
|
|
default='mst_mixer/mixer',
|
|
help='model name to train or test')
|
|
|
|
parser.add_argument(
|
|
'--mode',
|
|
type=str,
|
|
default='train',
|
|
help='train, generate or debug'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--eval_dir',
|
|
type=str,
|
|
default='ckpt/avsd'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--wandb_project',
|
|
type=str,
|
|
default='mst_mixer'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--wandb_mode',
|
|
type=str,
|
|
default='offline',
|
|
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--tag',
|
|
type=str,
|
|
default='full_model',
|
|
help="Tag to differentiate the models"
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--start_idx_gen',
|
|
type=int,
|
|
default=0,
|
|
help="The start index for generation"
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--end_idx_gen',
|
|
type=int,
|
|
default=10,
|
|
help="The end index for generation"
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--gen_subset_num',
|
|
type=int,
|
|
default=1,
|
|
help="The index of the test split for generation"
|
|
)
|
|
|
|
|
|
parser.add_argument('--ssh', action='store_true',
|
|
help='whether or not we are executing command via ssh. '
|
|
'If set to True, we will not log.info anything to screen and only redirect them to log file')
|
|
|
|
|
|
def main(gpu, config, args):
|
|
config['training'] = args.mode == 'train'
|
|
config['debugging'] = args.mode == 'debug'
|
|
config['generating'] = args.mode == 'generate'
|
|
config['wandb_project'] = args.wandb_project
|
|
config['wandb_mode'] = 'disabled'
|
|
if config['training']:
|
|
config['wandb_mode'] = args.wandb_mode
|
|
|
|
# When generating, only use 1 GPU
|
|
if config['generating']:
|
|
assert config['num_gpus'] == 1, 'When generating, only use 1 GPU!'
|
|
|
|
if config['parallel'] and config['dp_type'] != 'dp':
|
|
config['rank'] = gpu
|
|
dist.init_process_group(
|
|
backend='nccl',
|
|
# init_method='env://',
|
|
world_size=config['num_gpus'],
|
|
rank=gpu
|
|
)
|
|
config['display'] = gpu == 0
|
|
torch.cuda.set_device(gpu)
|
|
else:
|
|
config['display'] = True
|
|
if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
|
|
config['num_workers'] = 0
|
|
|
|
# set logs
|
|
if config['training']:
|
|
log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
|
|
set_log_file(log_file, file_only=args.ssh)
|
|
|
|
# print environment info
|
|
if config['display']:
|
|
log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
|
|
socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
|
|
log.info('Command line is: {}'.format(' '.join(sys.argv)))
|
|
|
|
if config['parallel'] and config['dp_type'] != 'dp':
|
|
log.info(f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
|
|
log.info(f"Running experiment: {args.model}")
|
|
log.info(f"Results saved to {config['log_dir']}")
|
|
|
|
# initialization
|
|
if config['display'] and config['training']:
|
|
copy_file_to_log(config['log_dir'])
|
|
set_random_seed(config['random_seed'])
|
|
|
|
device = torch.device(f"cuda:{gpu}")
|
|
if config["use_cpu"]:
|
|
device = torch.device("cpu")
|
|
config['device'] = device
|
|
|
|
# prepare datasets (train and validation)
|
|
dataset, dataset_eval = load_dataset(config)
|
|
|
|
# set training steps
|
|
if not config['generating'] or config['parallel']:
|
|
config = set_training_steps(config, len(dataset))
|
|
|
|
if config['display']:
|
|
log.info(pyhocon.HOCONConverter.convert(config, "hocon"))
|
|
|
|
# load runner
|
|
runner = load_runner(config, dataset.tokenizer, dataset.vocab_size)
|
|
|
|
# parallel
|
|
if config['parallel']:
|
|
if config['dp_type'] == 'ddp':
|
|
torch.cuda.set_device(gpu)
|
|
runner.model = runner.model.to(gpu)
|
|
runner.model = nn.parallel.DistributedDataParallel(
|
|
runner.model,
|
|
device_ids=[gpu],
|
|
output_device=gpu,
|
|
find_unused_parameters=True
|
|
)
|
|
else:
|
|
raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')
|
|
|
|
if config['training'] or config['debugging']:
|
|
ckpt_path = config.get('start_path', None)
|
|
runner.load_ckpt(ckpt_path=ckpt_path)
|
|
runner.train(dataset, dataset_eval)
|
|
|
|
elif config['generating']:
|
|
if config['loads_start_path']:
|
|
runner.load_ckpt(config['start_ckpt_for_generating'])
|
|
else:
|
|
runner.load_ckpt_best()
|
|
assert args.gen_subset_num > 0
|
|
# Load the data
|
|
if config['task'] == 'avsd':
|
|
# load the saved tokenizer
|
|
tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
|
|
test_dataset, _ = get_avsd_dataset(config, 'test', tokenizer)
|
|
assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
|
|
test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
|
|
runner.generate(
|
|
test_dataset, args.tag, tokenizer, gen_subset_num=args.gen_subset_num
|
|
)
|
|
|
|
elif config['task'] == 'nextqa':
|
|
# load the saved tokenizer
|
|
tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
|
|
test_dataset, app_feats, mot_feats = get_nextqa_dataset(config, 'test')
|
|
assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
|
|
test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
|
|
runner.generate(
|
|
test_dataset, app_feats, mot_feats, args.tag, tokenizer, args.start_idx_gen, args.end_idx_gen, gen_subset_num=args.gen_subset_num
|
|
)
|
|
else:
|
|
raise ValueError
|
|
|
|
if config['parallel']:
|
|
dist.destroy_process_group()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parser.parse_args()
|
|
|
|
# initialization
|
|
model_type, model_name = args.model.split('/')
|
|
config = initialize_from_env(model_name, args.mode, model_type, args.eval_dir, tag=args.tag)
|
|
if config['num_gpus'] > 1:
|
|
config['parallel'] = True
|
|
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
|
|
else:
|
|
config['parallel'] = False
|
|
main(0, config, args)
|