"""Entry point for v2dial: parses the CLI and launches stage-2 training/generation."""

import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from models.setup import setup_data, setup_model
from tasks.stage_2 import train as train_stage_2
from utils.init import initialize_from_env
# Command-line interface for the v2dial entry script.
# Fix: --eval_dir, --wandb_mode and --wandb_project previously had no help
# string, unlike every other option; all names, defaults and choices are
# unchanged.
parser = argparse.ArgumentParser(description='Main script for v2dial')

parser.add_argument(
    '--model',
    type=str,
    default='v2dial/stage_2',
    help='model name to train or test')

parser.add_argument(
    '--mode',
    type=str,
    default='train',
    help='train, generate or debug'
)

parser.add_argument(
    '--eval_dir',
    type=str,
    default='eval_dir',
    help='directory used for evaluation artifacts'
)

parser.add_argument(
    '--wandb_mode',
    type=str,
    default='online',
    choices=['online', 'offline', 'disabled', 'run', 'dryrun'],
    help='Weights & Biases logging mode'
)

parser.add_argument(
    '--wandb_project',
    type=str,
    default='V2Dial',
    help='Weights & Biases project name'
)

parser.add_argument(
    '--tag',
    type=str,
    default='experiment_tag',
    help="Tag to differentiate the models"
)

parser.add_argument(
    '--medium',
    type=str,
    default='avsd',
    help="Medium of the test dataset"
)

parser.add_argument(
    '--start_idx_gen',
    type=int,
    default=0,
    help="The start index for generation"
)

parser.add_argument(
    '--end_idx_gen',
    type=int,
    default=10,
    help="The end index for generation"
)

parser.add_argument(
    '--gen_subset_num',
    type=int,
    default=1,
    help="The index of the test split for generation"
)

parser.add_argument('--ssh', action='store_true',
                    help='whether or not we are executing command via ssh. '
                         'If set to True, we will not log.info anything to screen and only redirect them to log file')
|
|
|
|
|
|
def main(gpu, config, args):
    """Per-process entry point: prepare (optionally distributed) state, build the
    model, and run stage-2 training.

    Args:
        gpu: rank / local device index of this process (0 in single-process mode).
        config: run configuration produced by ``initialize_from_env`` and
            augmented by the ``__main__`` driver.
        args: parsed command-line arguments (currently unused here but passed
            through by ``mp.spawn``).
    """
    config['gpu'] = gpu

    # In multi-GPU mode each spawned worker joins the NCCL process group under
    # its rank and pins itself to its own CUDA device.
    if config['distributed']:
        dist.init_process_group(
            backend='nccl',
            world_size=config['num_gpus'],
            rank=gpu,
        )
        torch.cuda.set_device(gpu)

    # NOTE(review): config is read both by item (config['gpu']) and by
    # attribute (config.use_cpu) — presumably an attr-dict; confirm upstream.
    device = torch.device('cpu') if config.use_cpu else torch.device(f'cuda:{gpu}')
    config['device'] = device

    # Dataloaders are only needed (and only built) when training.
    if config['training']:
        train_dataloaders, val_dataloaders = setup_data(config)

    (model, model_without_ddp, optimizer, scheduler,
     scaler, start_epoch, global_step, config) = setup_model(config)

    if config['training']:
        train_stage_2(
            model,
            model_without_ddp,
            train_dataloaders,
            val_dataloaders,
            optimizer,
            global_step,
            scheduler,
            scaler,
            start_epoch,
            config,
        )

    if config['distributed']:
        dist.destroy_process_group()
|
|
|
|
if __name__ == '__main__':
    args = parser.parse_args()

    # --model is a '<model>/<stage>' spec, e.g. 'v2dial/stage_2'.
    model, stage = args.model.split('/')
    config = initialize_from_env(model, args.mode, stage, args.eval_dir, tag=args.tag)

    # Mode flags derived from --mode / --wandb_mode.
    config['wandb_enabled'] = args.wandb_mode == 'online'
    config['training'] = args.mode == 'train'
    config['generating'] = args.mode == 'generate'
    config['debugging'] = args.mode == 'debug'

    # Propagate the remaining CLI options into the run config.
    config['wandb_mode'] = args.wandb_mode
    config['medium'] = args.medium
    config['start_idx_gen'] = args.start_idx_gen
    config['end_idx_gen'] = args.end_idx_gen
    # Fix: --gen_subset_num was parsed but never copied into config, unlike
    # the other generation options above.
    config['gen_subset_num'] = args.gen_subset_num

    if config['num_gpus'] > 1:
        # One worker process per GPU; each receives its rank as the first
        # argument of main, followed by (config, args).
        config['distributed'] = True
        mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
    else:
        config['distributed'] = False
        main(0, config, args)
|