IRENE/train_tom.py

import random
from argparse import ArgumentParser
import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tom.model import GraphBC_T, GraphBCRNN
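# use the file-system strategy for sharing tensors between DataLoader workers;
# the default strategy can exhaust file descriptors ("too many open files")
# when num_workers > 0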
torch.multiprocessing.set_sharing_strategy('file_system')
parser = ArgumentParser()
# program level args
parser.add_argument('--seed', type=int, default=4)
# data specific args
parser.add_argument('--data_path', type=str, default='/datasets/external/bib_train/graphs/all_tasks/')
parser.add_argument('--types', nargs='+', type=str,
                    default=['preference', 'multi_agent', 'single_object', 'instrumental_action'],
                    help='types of tasks used for training / validation')
parser.add_argument('--train', type=int, default=1)  # 1 = train; 0 raises NotImplementedError below
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--model_type', type=str, default='graphbcrnn',
                    choices=['graphbct', 'graphbcrnn'])
# model specific args
parser_model = ArgumentParser()
parser_model = GraphBC_T.add_model_specific_args(parser_model)
# parser_model = GraphBCRNN.add_model_specific_args(parser_model)
# NOTE: unfortunately, the model whose args are added has to be selected
# manually here; keep the uncommented line in sync with --model_type
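# A possible (untested) alternative to the manual selection above: pre-parse
# --model_type with parse_known_args and pick the parser automatically, e.g.
#   pre_args, _ = parser.parse_known_args()
#   model_cls = GraphBC_T if pre_args.model_type == 'graphbct' else GraphBCRNN
#   parser_model = model_cls.add_model_specific_args(ArgumentParser())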
# add all the available trainer options to argparse
parser = Trainer.add_argparse_args(parser)
# combine parsers
parser_all = ArgumentParser(conflict_handler='resolve',
                            parents=[parser, parser_model])
# parse args
args = parser_all.parse_args()
args.types = sorted(args.types)
print(args)
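# seed the python, numpy and torch RNGs so runs are reproducible
# (CUDA kernels may still introduce some nondeterminism)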
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# init model
if args.model_type == 'graphbct':
    model = GraphBC_T(args)
elif args.model_type == 'graphbcrnn':
    model = GraphBCRNN(args)
else:
    raise NotImplementedError(f'unknown model_type: {args.model_type}')
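# anomaly detection pinpoints the op that produced NaN/Inf gradients,
# but it slows training noticeably; it is typically enabled only for debugging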
torch.autograd.set_detect_anomaly(True)
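# log metrics to Weights & Biases (requires wandb to be installed and logged in)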
logger = WandbLogger(project='bib')
trainer = Trainer(
    gradient_clip_val=args.gradient_clip_val,
    gpus=args.gpus,
    auto_select_gpus=args.auto_select_gpus,
    track_grad_norm=args.track_grad_norm,
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    max_epochs=args.max_epochs,
    accelerator=args.accelerator,
    resume_from_checkpoint=args.resume_from_checkpoint,
    stochastic_weight_avg=args.stochastic_weight_avg,
    num_sanity_val_steps=args.num_sanity_val_steps,
    logger=logger
)
if args.train:
    trainer.fit(model)
else:
    raise NotImplementedError('only training is implemented; run with --train 1')
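
# Example invocation (illustrative; flag names come from the model and the
# pytorch-lightning Trainer, and values depend on your setup):
#   python train_tom.py --model_type graphbct --gpus 1 --max_epochs 100 \
#       --batch_size 16 --data_path /datasets/external/bib_train/graphs/all_tasks/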