import argparse
import datetime
import os
import random

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloader import Data
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.tom_tf import TFToMnet
from models.single_mindnet import SingleMindNet
from utils import (pad_collate, get_classification_accuracy, mixup,
                   get_classification_accuracy_mixup, count_parameters,
                   get_input_dim)

# Model types whose dataloader must flatten features per-agent (flatten_dim=2).
TOM_MODEL_TYPES = ('tom_cm', 'tom_sl', 'tom_impl', 'tom_tf', 'tom_cm_xl', 'tom_single')

# Per-agent number of belief classes (both heads predict over 27 classes).
NUM_CLASSES = 27


def _make_dataloader(args, frame_path, pose_path, gaze_path, bbox_path,
                     flatten_dim, shuffle):
    """Build the dataset and DataLoader for one split (train or val)."""
    dataset = Data(
        frame_path,
        args.label_path,
        pose_path,
        gaze_path,
        bbox_path,
        args.ocr_graph_path,
        presaved=args.presaved,
        flatten_dim=flatten_dim
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory
    )


def _build_model(args, device):
    """Instantiate the model selected by ``args.model_type`` and move it to ``device``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not a known model type.
    """
    if args.model_type == 'resnet':
        return ResNet(get_input_dim(args.mods), device).to(device)
    if args.model_type == 'gru':
        return ResNetGRU(get_input_dim(args.mods), device).to(device)
    if args.model_type == 'lstm':
        return ResNetLSTM(get_input_dim(args.mods), device).to(device)
    if args.model_type == 'conv1d':
        return ResNetConv1D(get_input_dim(args.mods), device).to(device)
    if args.model_type == 'tom_cm':
        return CommonMindToMnet(args.hidden_dim, device, args.use_resnet,
                                args.dropout, args.aggr, args.mods).to(device)
    if args.model_type == 'tom_cm_xl':
        return CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet,
                                  args.dropout, args.aggr, args.mods).to(device)
    if args.model_type == 'tom_impl':
        return ImplicitToMnet(args.hidden_dim, device, args.use_resnet,
                              args.dropout, args.aggr, args.mods).to(device)
    if args.model_type == 'tom_sl':
        return SLToMnet(args.hidden_dim, device, args.tom_weight,
                        args.use_resnet, args.dropout, args.mods).to(device)
    if args.model_type == 'tom_single':
        return SingleMindNet(args.hidden_dim, device, args.use_resnet,
                             args.dropout, args.mods).to(device)
    if args.model_type == 'tom_tf':
        return TFToMnet(args.hidden_dim, device, args.use_resnet,
                        args.dropout, args.mods).to(device)
    raise NotImplementedError


def _to_device(tensors, device, non_blocking):
    """Move each tensor in ``tensors`` to ``device``; ``None`` entries pass through.

    Modalities that are disabled arrive as ``None`` from the collate function,
    so each entry is moved only when present.
    """
    return [t.to(device, non_blocking=non_blocking) if t is not None else None
            for t in tensors]


def train(args):
    """Train and validate a belief-prediction model, saving the best checkpoint.

    Runs ``args.num_epoch`` epochs of training with per-epoch validation,
    early stopping on validation accuracy (patience ``args.patience``), and
    writes the best model plus a log file into ``experiment_save_path``.

    NOTE(review): relies on the module-level globals ``experiment_save_path``
    and ``CFG`` set in the ``__main__`` block before this is called.
    """
    # ToM-style models consume per-agent (2D-flattened) features.
    flatten_dim = 2 if args.model_type in TOM_MODEL_TYPES else 1

    train_dataloader = _make_dataloader(
        args, args.train_frame_path, args.train_pose_path,
        args.train_gaze_path, args.train_bbox_path, flatten_dim, shuffle=True)
    val_dataloader = _make_dataloader(
        args, args.val_frame_path, args.val_pose_path,
        args.val_gaze_path, args.val_bbox_path, flatten_dim, shuffle=False)

    device = torch.device("cuda:{}".format(args.gpu_id)
                          if torch.cuda.is_available() else 'cpu')

    model = _build_model(args, device)
    if args.resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(args.resume_from_checkpoint,
                                         map_location=device))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    if args.scheduler is None:
        scheduler = None
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)

    if args.model_type == 'tom_sl':
        # SLToMnet emits log-probabilities, so NLL is used directly.
        cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
    else:
        cross_entropy_loss = nn.CrossEntropyLoss(
            label_smoothing=args.label_smoothing, ignore_index=-1).to(device)

    stats = {'train': {'cls_loss': [], 'cls_acc': []},
             'val': {'cls_loss': [], 'cls_acc': []}}
    max_val_classification_acc = 0
    max_val_classification_epoch = None
    counter = 0  # early-stopping counter

    print(f'Number of parameters: {count_parameters(model)}')

    for i in range(args.num_epoch):

        # ---------------------------------------------------------- training
        print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
        temp_train_classification_loss = []
        epoch_num_correct = 0
        epoch_cnt = 0
        model.train()
        for j, batch in tqdm(enumerate(train_dataloader)):
            if args.use_mixup:
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = \
                    mixup(batch, args.mixup_alpha, NUM_CLASSES)
            else:
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            frames, poses, gazes, bboxes, ocr_graphs = _to_device(
                (frames, poses, gazes, bboxes, ocr_graphs),
                device, args.non_blocking)

            pred_left_labels, pred_right_labels, _ = model(
                frames, poses, gazes, bboxes, ocr_graphs)
            pred_left_labels = torch.reshape(pred_left_labels, (-1, NUM_CLASSES))
            pred_right_labels = torch.reshape(pred_right_labels, (-1, NUM_CLASSES))

            if args.use_mixup:
                # Mixup labels are soft distributions: one 27-way block per agent.
                labels = torch.reshape(labels, (-1, 2 * NUM_CLASSES)).to(device)
                batch_train_acc, batch_num_correct, batch_num_pred = \
                    get_classification_accuracy_mixup(
                        pred_left_labels, pred_right_labels, labels, sequence_lengths)
                loss = cross_entropy_loss(pred_left_labels, labels[:, :NUM_CLASSES]) \
                    + cross_entropy_loss(pred_right_labels, labels[:, NUM_CLASSES:])
            else:
                # Hard labels: one class index per agent.
                labels = torch.reshape(labels, (-1, 2)).to(device)
                batch_train_acc, batch_num_correct, batch_num_pred = \
                    get_classification_accuracy(
                        pred_left_labels, pred_right_labels, labels, sequence_lengths)
                loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) \
                    + cross_entropy_loss(pred_right_labels, labels[:, 1])

            epoch_cnt += batch_num_pred
            epoch_num_correct += batch_num_correct
            # loss sums two heads; halve when un-averaging for the epoch mean.
            temp_train_classification_loss.append(loss.item() * batch_num_pred / 2)

            optimizer.zero_grad()
            loss.backward()
            # BUGFIX: clip AFTER backward (original clipped zeroed grads, a no-op).
            if args.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            if args.logger:
                wandb.log({'batch_train_acc': batch_train_acc,
                           'batch_train_loss': loss.item(),
                           'lr': optimizer.param_groups[-1]['lr']})
            print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
                i+1, args.num_epoch, j+1, len(train_dataloader), loss.item(), batch_train_acc)
            )

        if scheduler:
            scheduler.step()

        print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
            i+1, args.num_epoch,
            sum(temp_train_classification_loss) * 2 / epoch_cnt,
            epoch_num_correct / epoch_cnt)
        )
        stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
        stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)

        # -------------------------------------------------------- validation
        print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
        temp_val_classification_loss = []
        epoch_num_correct = 0
        epoch_cnt = 0
        model.eval()
        with torch.no_grad():
            for j, batch in tqdm(enumerate(val_dataloader)):
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
                frames, poses, gazes, bboxes, ocr_graphs = _to_device(
                    (frames, poses, gazes, bboxes, ocr_graphs),
                    device, args.non_blocking)

                pred_left_labels, pred_right_labels, _ = model(
                    frames, poses, gazes, bboxes, ocr_graphs)
                pred_left_labels = torch.reshape(pred_left_labels, (-1, NUM_CLASSES))
                pred_right_labels = torch.reshape(pred_right_labels, (-1, NUM_CLASSES))
                labels = torch.reshape(labels, (-1, 2)).to(device)

                batch_val_acc, batch_num_correct, batch_num_pred = \
                    get_classification_accuracy(
                        pred_left_labels, pred_right_labels, labels, sequence_lengths)
                epoch_cnt += batch_num_pred
                epoch_num_correct += batch_num_correct
                loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) \
                    + cross_entropy_loss(pred_right_labels, labels[:, 1])
                temp_val_classification_loss.append(loss.item() * batch_num_pred / 2)

                if args.logger:
                    wandb.log({'batch_val_acc': batch_val_acc,
                               'batch_val_loss': loss.item()})
                print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
                    i+1, args.num_epoch, j+1, len(val_dataloader), loss.item(), batch_val_acc)
                )

        print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
            i+1, args.num_epoch,
            sum(temp_val_classification_loss) * 2 / epoch_cnt,
            epoch_num_correct / epoch_cnt)
        )
        cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
        cls_acc = epoch_num_correct / epoch_cnt
        stats['val']['cls_loss'].append(cls_loss)
        stats['val']['cls_acc'].append(cls_acc)
        if args.logger:
            wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})

        # checkpoint on best validation accuracy; early-stop otherwise
        if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
            max_val_classification_acc = stats['val']['cls_acc'][-1]
            max_val_classification_epoch = i+1
            torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(CFG))
        f.write('{}\n'.format(stats))
        f.write('Max val classification acc: epoch {}, {}\n'.format(
            max_val_classification_epoch, max_val_classification_acc))

    print(f'Results saved in {experiment_save_path}')


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str,
                        default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--train_frame_path', type=str,
                        default='/scratch/bortoletto/data/boss/train/frame')
    parser.add_argument('--train_pose_path', type=str,
                        default='/scratch/bortoletto/data/boss/train/pose')
    parser.add_argument('--train_gaze_path', type=str,
                        default='/scratch/bortoletto/data/boss/train/gaze')
    parser.add_argument('--train_bbox_path', type=str,
                        default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
    parser.add_argument('--val_frame_path', type=str,
                        default='/scratch/bortoletto/data/boss/val/frame')
    parser.add_argument('--val_pose_path', type=str,
                        default='/scratch/bortoletto/data/boss/val/pose')
    parser.add_argument('--val_gaze_path', type=str,
                        default='/scratch/bortoletto/data/boss/val/gaze')
    parser.add_argument('--val_bbox_path', type=str,
                        default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)

    # Parse the command-line arguments
    args = parser.parse_args()

    # cross-argument validation
    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
        if not args.aggr:
            parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")

    # get experiment ID
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    os.makedirs(args.save_path, exist_ok=True)
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    os.makedirs(experiment_save_path, exist_ok=True)

    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed
    }

    print(CFG)
    print(f'Saving results in {experiment_save_path}')

    # set seed values
    if args.logger:
        wandb.init(project="boss", config=CFG)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    train(args)