"""Train a Theory-of-Mind graph behavior-cloning model (GraphBC_T or GraphBCRNN).

Parses program-, data-, model- and Trainer-level CLI arguments, seeds all RNGs,
builds the requested model and fits it with a PyTorch Lightning ``Trainer``
logging to Weights & Biases (project ``bib``).
"""
import random
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from tom.model import GraphBC_T, GraphBCRNN

# Avoid "too many open files" errors when DataLoader workers share tensors.
torch.multiprocessing.set_sharing_strategy('file_system')

# Dispatch table: --model_type value -> model class.
MODEL_CLASSES = {
    'graphbct': GraphBC_T,
    'graphbcrnn': GraphBCRNN,
}


def main():
    """Build the argument parsers, seed RNGs, construct the model and train."""
    parser = ArgumentParser()
    # program level args
    parser.add_argument('--seed', type=int, default=4)
    # data specific args
    parser.add_argument('--data_path', type=str,
                        default='/datasets/external/bib_train/graphs/all_tasks/')
    parser.add_argument('--types', nargs='+', type=str,
                        default=['preference', 'multi_agent', 'single_object',
                                 'instrumental_action'],
                        help='types of tasks used for training / validation')
    parser.add_argument('--train', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_type', type=str, default='graphbcrnn',
                        choices=sorted(MODEL_CLASSES))

    # Peek at --model_type with a help-free pre-parser so the matching model's
    # specific args are added automatically.  Previously GraphBC_T's args were
    # hard-coded and switching models required editing this file (see the old
    # "NOTE: here unfortunately you have to select manually the model").
    pre_parser = ArgumentParser(add_help=False)
    pre_parser.add_argument('--model_type', type=str, default='graphbcrnn',
                            choices=sorted(MODEL_CLASSES))
    known_args, _ = pre_parser.parse_known_args()
    model_cls = MODEL_CLASSES[known_args.model_type]

    # model specific args
    parser_model = ArgumentParser()
    parser_model = model_cls.add_model_specific_args(parser_model)

    # add all the available trainer options to argparse
    parser = Trainer.add_argparse_args(parser)

    # combine parsers; conflict_handler='resolve' lets later parents override
    # options defined by earlier ones
    parser_all = ArgumentParser(conflict_handler='resolve',
                                parents=[parser, parser_model])

    # parse args
    args = parser_all.parse_args()
    # canonical ordering so runs with the same task set compare equal
    args.types = sorted(args.types)
    print(args)

    # seed every RNG source for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # init model
    model = model_cls(args)

    # NOTE(review): anomaly detection slows training noticeably; presumably
    # left on for debugging — confirm before long runs.
    torch.autograd.set_detect_anomaly(True)
    logger = WandbLogger(project='bib')
    trainer = Trainer(
        gradient_clip_val=args.gradient_clip_val,
        gpus=args.gpus,
        auto_select_gpus=args.auto_select_gpus,
        track_grad_norm=args.track_grad_norm,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        resume_from_checkpoint=args.resume_from_checkpoint,
        stochastic_weight_avg=args.stochastic_weight_avg,
        num_sanity_val_steps=args.num_sanity_val_steps,
        logger=logger,
    )

    if args.train:
        trainer.fit(model)
    else:
        # evaluation-only mode is not implemented in this script
        raise NotImplementedError


if __name__ == '__main__':
    # Guarding the entry point is required for safe multiprocessing spawn
    # (DataLoader workers / DDP re-import this module).
    main()