# MST-MIXER/main.py

from init_utils import (
    load_runner,
    load_dataset,
    set_random_seed,
    set_training_steps,
    initialize_from_env,
    set_log_file,
    copy_file_to_log
)
import os
import sys
import argparse
import pyhocon
import glog as log
import socket
import getpass
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.distributed as dist
from transformers import BartTokenizer
from custom_datasets.avsd import get_dataset as get_avsd_dataset
from custom_datasets.nextqa import get_dataset as get_nextqa_dataset

parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
parser.add_argument(
    '--model',
    type=str,
    default='mst_mixer/mixer',
    help='model name to train or test')
parser.add_argument(
    '--mode',
    type=str,
    default='train',
    help='train, generate or debug'
)
parser.add_argument(
    '--eval_dir',
    type=str,
    default='ckpt/avsd'
)
parser.add_argument(
    '--wandb_project',
    type=str,
    default='mst_mixer'
)
parser.add_argument(
    '--wandb_mode',
    type=str,
    default='offline',
    choices=['online', 'offline', 'disabled', 'run', 'dryrun']
)
parser.add_argument(
    '--tag',
    type=str,
    default='full_model',
    help="Tag to differentiate the models"
)
parser.add_argument(
    '--start_idx_gen',
    type=int,
    default=0,
    help="The start index for generation"
)
parser.add_argument(
    '--end_idx_gen',
    type=int,
    default=10,
    help="The end index for generation"
)
parser.add_argument(
    '--gen_subset_num',
    type=int,
    default=1,
    help="The index of the test split for generation"
)
parser.add_argument('--ssh', action='store_true',
                    help='whether or not the command is executed via ssh. '
                         'If set, nothing is printed to the screen; all output is redirected to the log file')
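
# Example invocations (illustrative only; the available model/config names are
# defined by the HOCON config files shipped with this repo):
#   python main.py --model mst_mixer/mixer --mode train --tag full_model
#   python main.py --model mst_mixer/mixer --mode generate --eval_dir ckpt/avsd \
#       --start_idx_gen 0 --end_idx_gen 10 --gen_subset_num 1
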
def main(gpu, config, args):
    config['training'] = args.mode == 'train'
    config['debugging'] = args.mode == 'debug'
    config['generating'] = args.mode == 'generate'
    config['wandb_project'] = args.wandb_project
    config['wandb_mode'] = 'disabled'
    if config['training']:
        config['wandb_mode'] = args.wandb_mode

    # When generating, only use 1 GPU
    if config['generating']:
        assert config['num_gpus'] == 1, 'When generating, only use 1 GPU!'

    if config['parallel'] and config['dp_type'] != 'dp':
        config['rank'] = gpu
        # default init_method is 'env://', i.e. MASTER_ADDR and MASTER_PORT are
        # expected in the environment (presumably set by initialize_from_env)
        dist.init_process_group(
            backend='nccl',
            # init_method='env://',
            world_size=config['num_gpus'],
            rank=gpu
        )
        # only rank 0 logs to the screen
        config['display'] = gpu == 0
        torch.cuda.set_device(gpu)
    else:
        config['display'] = True

    if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
        config['num_workers'] = 0
    # set logs
    if config['training']:
        log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
        set_log_file(log_file, file_only=args.ssh)

    # print environment info
    if config['display']:
        log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
            socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
        log.info('Command line is: {}'.format(' '.join(sys.argv)))

        if config['parallel'] and config['dp_type'] != 'dp':
            log.info(f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
        log.info(f"Running experiment: {args.model}")
        log.info(f"Results saved to {config['log_dir']}")

    # initialization
    if config['display'] and config['training']:
        copy_file_to_log(config['log_dir'])
    set_random_seed(config['random_seed'])

    device = torch.device(f"cuda:{gpu}")
    if config["use_cpu"]:
        device = torch.device("cpu")
    config['device'] = device
    # prepare datasets (train and validation)
    dataset, dataset_eval = load_dataset(config)

    # set training steps
    if not config['generating'] or config['parallel']:
        config = set_training_steps(config, len(dataset))

    if config['display']:
        log.info(pyhocon.HOCONConverter.convert(config, "hocon"))

    # load runner
    runner = load_runner(config, dataset.tokenizer, dataset.vocab_size)

    # parallel
    if config['parallel']:
        if config['dp_type'] == 'ddp':
            torch.cuda.set_device(gpu)
            runner.model = runner.model.to(gpu)
            runner.model = nn.parallel.DistributedDataParallel(
                runner.model,
                device_ids=[gpu],
                output_device=gpu,
                find_unused_parameters=True
            )
        else:
            raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')
    if config['training'] or config['debugging']:
        ckpt_path = config.get('start_path', None)
        runner.load_ckpt(ckpt_path=ckpt_path)
        runner.train(dataset, dataset_eval)
    elif config['generating']:
        if config['loads_start_path']:
            runner.load_ckpt(config['start_ckpt_for_generating'])
        else:
            runner.load_ckpt_best()

        assert args.gen_subset_num > 0
        # Load the data
        if config['task'] == 'avsd':
            # load the saved tokenizer
            tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
            test_dataset, _ = get_avsd_dataset(config, 'test', tokenizer)
            assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
            test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
            runner.generate(
                test_dataset, args.tag, tokenizer, gen_subset_num=args.gen_subset_num
            )
        elif config['task'] == 'nextqa':
            # load the saved tokenizer
            tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
            test_dataset, app_feats, mot_feats = get_nextqa_dataset(config, 'test')
            assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
            test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
            runner.generate(
                test_dataset, app_feats, mot_feats, args.tag, tokenizer, args.start_idx_gen, args.end_idx_gen, gen_subset_num=args.gen_subset_num
            )
        else:
            raise ValueError(f'Unknown task: {config["task"]}')

    if config['parallel']:
        dist.destroy_process_group()


if __name__ == '__main__':
    args = parser.parse_args()

    # initialization
    model_type, model_name = args.model.split('/')
    config = initialize_from_env(model_name, args.mode, model_type, args.eval_dir, tag=args.tag)
    if config['num_gpus'] > 1:
        config['parallel'] = True
        mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
    else:
        config['parallel'] = False
        main(0, config, args)
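
# Note on multi-GPU runs: main() calls dist.init_process_group() without an
# explicit init_method, so PyTorch falls back to 'env://' and expects
# MASTER_ADDR and MASTER_PORT in the environment. Presumably initialize_from_env()
# takes care of this; if not, export them before launching, e.g.
# MASTER_ADDR=localhost MASTER_PORT=29500 (placeholder values only).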