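"""Train a graph behavior-cloning model (GraphBC_T or GraphBCRNN) with PyTorch Lightning.

Runs as a script at import time: parses command-line arguments, seeds the
random number generators, builds the model selected by ``--model_type`` and,
when ``--train`` is non-zero, fits it while logging to the Weights & Biases
project ``bib``.
"""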
import random
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from tom.model import GraphBC_T, GraphBCRNN

# Use torch's 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid exhausting the open-file-descriptor
# limit when many DataLoader workers are used (--num_workers) — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

# Every accepted --model_type value and the model class it selects.
DEFAULT_MODEL_TYPE = 'graphbcrnn'
MODEL_CLASSES = {
    'graphbct': GraphBC_T,
    'graphbcrnn': GraphBCRNN,
}

parser = ArgumentParser()

# program level args
parser.add_argument('--seed', type=int, default=4)

# data specific args
parser.add_argument('--data_path', type=str, default='/datasets/external/bib_train/graphs/all_tasks/')
parser.add_argument('--types', nargs='+', type=str,
                    default=['preference', 'multi_agent', 'single_object', 'instrumental_action'],
                    help='types of tasks used for training / validation')
parser.add_argument('--train', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--model_type', type=str, default=DEFAULT_MODEL_TYPE,
                    choices=sorted(MODEL_CLASSES),
                    help='model to train; also selects which model specific args are accepted')


def _requested_model_type():
    """Return the ``--model_type`` given on the command line, or its default.

    Only ``--model_type`` is parsed here; every other argument is ignored and
    left for the full parser. ``add_help=False`` keeps ``-h``/``--help`` from
    being handled here, so the full help (including the selected model's own
    arguments) is printed by ``parser_all`` instead.
    """
    pre_parser = ArgumentParser(add_help=False)
    pre_parser.add_argument('--model_type', type=str, default=DEFAULT_MODEL_TYPE)
    known, _unknown = pre_parser.parse_known_args()
    return known.model_type


# model specific args: registered from the model that --model_type selects, so
# the accepted arguments always match the model that will be built. An unknown
# type falls back to GraphBC_T here; the `choices` check on --model_type then
# reports the bad value when the full parser runs.
parser_model = ArgumentParser()
model_class = MODEL_CLASSES.get(_requested_model_type(), GraphBC_T)
parser_model = model_class.add_model_specific_args(parser_model)

# add all the available trainer options to argparse
parser = Trainer.add_argparse_args(parser)

# combine parsers; 'resolve' lets a later definition of a duplicated option
# replace the earlier one (e.g. the -h/--help that each ArgumentParser adds)
parser_all = ArgumentParser(conflict_handler='resolve',
                            parents=[parser, parser_model])

# parse args
args = parser_all.parse_args()
# sort task types so the same set of types always appears in the same order
args.types = sorted(args.types)
print(args)

# Seed Python's, NumPy's and PyTorch's global RNGs with --seed, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# init model: build the network chosen by --model_type from the parsed args
if args.model_type == 'graphbct':
    model = GraphBC_T(args)
elif args.model_type == 'graphbcrnn':
    model = GraphBCRNN(args)
else:
    # Name the offending value rather than failing with a bare exception.
    raise NotImplementedError(
        f"unknown --model_type {args.model_type!r}; "
        "expected 'graphbct' or 'graphbcrnn'"
    )

# Enable autograd anomaly detection for the whole run.
# NOTE(review): PyTorch documents anomaly mode as a debugging aid that slows
# execution; consider turning it off once training is known to be stable.
torch.autograd.set_detect_anomaly(True)

logger = WandbLogger(project='bib')

# Build the Trainer from every Trainer option parsed on the command line.
# Previously only a hand-picked subset was forwarded, so other flags that
# Trainer.add_argparse_args advertises (e.g. --precision,
# --limit_train_batches) were silently ignored. Options that are not Trainer
# parameters (--seed, --types, model args, ...) are filtered out by
# from_argparse_args. The explicit W&B logger overrides any --logger flag.
trainer = Trainer.from_argparse_args(args, logger=logger)

# Only training is wired up: running with --train 0 raises immediately.
if not args.train:
    raise NotImplementedError
trainer.fit(model)