import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR

from tbd_dataloader import TBDDataset, collate_fn
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import count_parameters, compute_f1_scores


def main(args):
    train_dataset = TBDDataset(
        path=args.data_path,
        mode="train",
        use_preprocessed_img=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn
    )
    val_dataset = TBDDataset(
        path=args.data_path,
        mode="val",
        use_preprocessed_img=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn
    )

    train_data = [data[1] for data in train_dataset.data]
    val_data = [data[1] for data in val_dataset.data]
    if args.logger:
        wandb.config.update({"train_data": train_data})
        wandb.config.update({"val_data": val_data})

    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # model
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
    else:
        raise NotImplementedError
    if args.resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # scheduler
    if args.scheduler is None:
        scheduler = None
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)

    # loss functions: NLLLoss for tom_sl (whose outputs are presumably
    # log-probabilities), CrossEntropyLoss with label smoothing otherwise
    if args.model_type == 'tom_sl':
        ce_loss_m1 = nn.NLLLoss()
        ce_loss_m2 = nn.NLLLoss()
        ce_loss_m12 = nn.NLLLoss()
        ce_loss_m21 = nn.NLLLoss()
        ce_loss_mc = nn.NLLLoss()
    else:
        ce_loss_m1 = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        ce_loss_m2 = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        ce_loss_m12 = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        ce_loss_m21 = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        ce_loss_mc = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    stats = {
        'train': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [],
                  'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []},
        'val': {'loss_m1': [], 'loss_m2': [], 'loss_m12': [], 'loss_m21': [], 'loss_mc': [],
                'm1_f1': [], 'm2_f1': [], 'm12_f1': [], 'm21_f1': [], 'mc_f1': []}
    }
    max_val_f1 = 0
    max_val_classification_epoch = None
    counter = 0

    print(f'Number of parameters: {count_parameters(model)}')

    for i in range(args.num_epoch):
        # training
        print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
        epoch_train_loss_m1 = 0.0
        epoch_train_loss_m2 = 0.0
        epoch_train_loss_m12 = 0.0
        epoch_train_loss_m21 = 0.0
        epoch_train_loss_mc = 0.0
        m1_train_batch_pred_list = []
        m2_train_batch_pred_list = []
        m12_train_batch_pred_list = []
        m21_train_batch_pred_list = []
        mc_train_batch_pred_list = []
        m1_train_batch_label_list = []
        m2_train_batch_label_list = []
        m12_train_batch_label_list = []
        m21_train_batch_label_list = []
        mc_train_batch_label_list = []
        model.train()

        for j, batch in tqdm(enumerate(train_dataloader)):
            img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
            # move the modalities present in this batch to the device
            if img_3rd_pov is not None:
                img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
            if img_tracker is not None:
                img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
            if img_battery is not None:
                img_battery = img_battery.to(device, non_blocking=args.non_blocking)
            if pose1 is not None:
                pose1 = pose1.to(device, non_blocking=args.non_blocking)
            if pose2 is not None:
                pose2 = pose2.to(device, non_blocking=args.non_blocking)
            if bbox is not None:
                bbox = bbox.to(device, non_blocking=args.non_blocking)
            if gaze is not None:
                gaze = gaze.to(device, non_blocking=args.non_blocking)

            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze
            )

            m1_pred = m1_pred.reshape(-1, 4)
            m2_pred = m2_pred.reshape(-1, 4)
            m12_pred = m12_pred.reshape(-1, 4)
            m21_pred = m21_pred.reshape(-1, 4)
            mc_pred = mc_pred.reshape(-1, 4)

            m1_label = labels[:, 0].reshape(-1).to(device)
            m2_label = labels[:, 1].reshape(-1).to(device)
            m12_label = labels[:, 2].reshape(-1).to(device)
            m21_label = labels[:, 3].reshape(-1).to(device)
            mc_label = labels[:, 4].reshape(-1).to(device)

            loss_m1 = ce_loss_m1(m1_pred, m1_label)
            loss_m2 = ce_loss_m2(m2_pred, m2_label)
            loss_m12 = ce_loss_m12(m12_pred, m12_label)
            loss_m21 = ce_loss_m21(m21_pred, m21_label)
            loss_mc = ce_loss_mc(mc_pred, mc_label)
            loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc

            epoch_train_loss_m1 += loss_m1.item()
            epoch_train_loss_m2 += loss_m2.item()
            epoch_train_loss_m12 += loss_m12.item()
            epoch_train_loss_m21 += loss_m21.item()
            epoch_train_loss_mc += loss_mc.item()

            m1_train_batch_pred_list.append(m1_pred)
            m2_train_batch_pred_list.append(m2_pred)
            m12_train_batch_pred_list.append(m12_pred)
            m21_train_batch_pred_list.append(m21_pred)
            mc_train_batch_pred_list.append(mc_pred)
            m1_train_batch_label_list.append(m1_label)
            m2_train_batch_label_list.append(m2_label)
            m12_train_batch_label_list.append(m12_label)
            m21_train_batch_label_list.append(m21_label)
            mc_train_batch_label_list.append(mc_label)

            # compute gradients, clip their norm, then update the parameters
            optimizer.zero_grad()
            loss.backward()
            if args.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            if args.logger:
                wandb.log({
                    'batch_train_loss': loss.item(),
                    'lr': optimizer.param_groups[-1]['lr']
                })
            print("Epoch {}/{} batch {}/{} training done with loss={}".format(
                i+1, args.num_epoch, j+1, len(train_dataloader), loss.item())
            )

        if scheduler:
            scheduler.step()

        train_m1_f1_score, train_m2_f1_score, train_m12_f1_score, train_m21_f1_score, train_mc_f1_score = compute_f1_scores(
            torch.cat(m1_train_batch_pred_list), torch.cat(m1_train_batch_label_list),
            torch.cat(m2_train_batch_pred_list), torch.cat(m2_train_batch_label_list),
            torch.cat(m12_train_batch_pred_list), torch.cat(m12_train_batch_label_list),
            torch.cat(m21_train_batch_pred_list), torch.cat(m21_train_batch_label_list),
            torch.cat(mc_train_batch_pred_list), torch.cat(mc_train_batch_label_list)
        )

        print("Epoch {}/{} OVERALL train m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
            i+1,
            args.num_epoch,
            epoch_train_loss_m1/len(train_dataloader),
            epoch_train_loss_m2/len(train_dataloader),
            epoch_train_loss_m12/len(train_dataloader),
            epoch_train_loss_m21/len(train_dataloader),
            epoch_train_loss_mc/len(train_dataloader),
            train_m1_f1_score,
            train_m2_f1_score,
            train_m12_f1_score,
            train_m21_f1_score,
            train_mc_f1_score
            )
        )

        stats['train']['loss_m1'].append(epoch_train_loss_m1/len(train_dataloader))
        stats['train']['loss_m2'].append(epoch_train_loss_m2/len(train_dataloader))
        stats['train']['loss_m12'].append(epoch_train_loss_m12/len(train_dataloader))
        stats['train']['loss_m21'].append(epoch_train_loss_m21/len(train_dataloader))
        stats['train']['loss_mc'].append(epoch_train_loss_mc/len(train_dataloader))
        stats['train']['m1_f1'].append(train_m1_f1_score)
        stats['train']['m2_f1'].append(train_m2_f1_score)
        stats['train']['m12_f1'].append(train_m12_f1_score)
        stats['train']['m21_f1'].append(train_m21_f1_score)
        stats['train']['mc_f1'].append(train_mc_f1_score)

        if args.logger:
            wandb.log(
                {
                    'train_m1_loss': epoch_train_loss_m1/len(train_dataloader),
                    'train_m2_loss': epoch_train_loss_m2/len(train_dataloader),
                    'train_m12_loss': epoch_train_loss_m12/len(train_dataloader),
                    'train_m21_loss': epoch_train_loss_m21/len(train_dataloader),
                    'train_mc_loss': epoch_train_loss_mc/len(train_dataloader),
                    'train_loss': epoch_train_loss_m1/len(train_dataloader) +
                                  epoch_train_loss_m2/len(train_dataloader) +
                                  epoch_train_loss_m12/len(train_dataloader) +
                                  epoch_train_loss_m21/len(train_dataloader) +
                                  epoch_train_loss_mc/len(train_dataloader),
                    'train_m1_f1_score': train_m1_f1_score,
                    'train_m2_f1_score': train_m2_f1_score,
                    'train_m12_f1_score': train_m12_f1_score,
                    'train_m21_f1_score': train_m21_f1_score,
                    'train_mc_f1_score': train_mc_f1_score
                }
            )

        # validation
        print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
        epoch_val_loss_m1 = 0.0
        epoch_val_loss_m2 = 0.0
        epoch_val_loss_m12 = 0.0
        epoch_val_loss_m21 = 0.0
        epoch_val_loss_mc = 0.0
        m1_val_batch_pred_list = []
        m2_val_batch_pred_list = []
        m12_val_batch_pred_list = []
        m21_val_batch_pred_list = []
        mc_val_batch_pred_list = []
        m1_val_batch_label_list = []
        m2_val_batch_label_list = []
        m12_val_batch_label_list = []
        m21_val_batch_label_list = []
        mc_val_batch_label_list = []
        model.eval()
        with torch.no_grad():
            for j, batch in tqdm(enumerate(val_dataloader)):
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch
                if img_3rd_pov is not None:
                    img_3rd_pov = img_3rd_pov.to(device, non_blocking=args.non_blocking)
                if img_tracker is not None:
                    img_tracker = img_tracker.to(device, non_blocking=args.non_blocking)
                if img_battery is not None:
                    img_battery = img_battery.to(device, non_blocking=args.non_blocking)
                if pose1 is not None:
                    pose1 = pose1.to(device, non_blocking=args.non_blocking)
                if pose2 is not None:
                    pose2 = pose2.to(device, non_blocking=args.non_blocking)
                if bbox is not None:
                    bbox = bbox.to(device, non_blocking=args.non_blocking)
                if gaze is not None:
                    gaze = gaze.to(device, non_blocking=args.non_blocking)

                m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(
                    img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze
                )

                m1_pred = m1_pred.reshape(-1, 4)
                m2_pred = m2_pred.reshape(-1, 4)
                m12_pred = m12_pred.reshape(-1, 4)
                m21_pred = m21_pred.reshape(-1, 4)
                mc_pred = mc_pred.reshape(-1, 4)

                m1_label = labels[:, 0].reshape(-1).to(device)
                m2_label = labels[:, 1].reshape(-1).to(device)
                m12_label = labels[:, 2].reshape(-1).to(device)
                m21_label = labels[:, 3].reshape(-1).to(device)
                mc_label = labels[:, 4].reshape(-1).to(device)

                loss_m1 = ce_loss_m1(m1_pred, m1_label)
                loss_m2 = ce_loss_m2(m2_pred, m2_label)
                loss_m12 = ce_loss_m12(m12_pred, m12_label)
                loss_m21 = ce_loss_m21(m21_pred, m21_label)
                loss_mc = ce_loss_mc(mc_pred, mc_label)
                loss = loss_m1 + loss_m2 + loss_m12 + loss_m21 + loss_mc

                epoch_val_loss_m1 += loss_m1.item()
                epoch_val_loss_m2 += loss_m2.item()
                epoch_val_loss_m12 += loss_m12.item()
                epoch_val_loss_m21 += loss_m21.item()
                epoch_val_loss_mc += loss_mc.item()

                m1_val_batch_pred_list.append(m1_pred)
                m2_val_batch_pred_list.append(m2_pred)
                m12_val_batch_pred_list.append(m12_pred)
                m21_val_batch_pred_list.append(m21_pred)
                mc_val_batch_pred_list.append(mc_pred)
                m1_val_batch_label_list.append(m1_label)
                m2_val_batch_label_list.append(m2_label)
                m12_val_batch_label_list.append(m12_label)
                m21_val_batch_label_list.append(m21_label)
                mc_val_batch_label_list.append(mc_label)

                if args.logger:
                    wandb.log({'batch_val_loss': loss.item()})
                print("Epoch {}/{} batch {}/{} validation done with loss={}".format(
                    i+1, args.num_epoch, j+1, len(val_dataloader), loss.item())
                )

        val_m1_f1_score, val_m2_f1_score, val_m12_f1_score, val_m21_f1_score, val_mc_f1_score = compute_f1_scores(
            torch.cat(m1_val_batch_pred_list), torch.cat(m1_val_batch_label_list),
            torch.cat(m2_val_batch_pred_list), torch.cat(m2_val_batch_label_list),
            torch.cat(m12_val_batch_pred_list), torch.cat(m12_val_batch_label_list),
            torch.cat(m21_val_batch_pred_list), torch.cat(m21_val_batch_label_list),
            torch.cat(mc_val_batch_pred_list), torch.cat(mc_val_batch_label_list)
        )

        print("Epoch {}/{} OVERALL validation m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
            i+1,
            args.num_epoch,
            epoch_val_loss_m1/len(val_dataloader),
            epoch_val_loss_m2/len(val_dataloader),
            epoch_val_loss_m12/len(val_dataloader),
            epoch_val_loss_m21/len(val_dataloader),
            epoch_val_loss_mc/len(val_dataloader),
            val_m1_f1_score,
            val_m2_f1_score,
            val_m12_f1_score,
            val_m21_f1_score,
            val_mc_f1_score
            )
        )

        stats['val']['loss_m1'].append(epoch_val_loss_m1/len(val_dataloader))
        stats['val']['loss_m2'].append(epoch_val_loss_m2/len(val_dataloader))
        stats['val']['loss_m12'].append(epoch_val_loss_m12/len(val_dataloader))
        stats['val']['loss_m21'].append(epoch_val_loss_m21/len(val_dataloader))
        stats['val']['loss_mc'].append(epoch_val_loss_mc/len(val_dataloader))
        stats['val']['m1_f1'].append(val_m1_f1_score)
        stats['val']['m2_f1'].append(val_m2_f1_score)
        stats['val']['m12_f1'].append(val_m12_f1_score)
        stats['val']['m21_f1'].append(val_m21_f1_score)
        stats['val']['mc_f1'].append(val_mc_f1_score)

        if args.logger:
            wandb.log(
                {
                    'val_m1_loss': epoch_val_loss_m1/len(val_dataloader),
                    'val_m2_loss': epoch_val_loss_m2/len(val_dataloader),
                    'val_m12_loss': epoch_val_loss_m12/len(val_dataloader),
                    'val_m21_loss': epoch_val_loss_m21/len(val_dataloader),
                    'val_mc_loss': epoch_val_loss_mc/len(val_dataloader),
                    'val_loss': epoch_val_loss_m1/len(val_dataloader) +
                                epoch_val_loss_m2/len(val_dataloader) +
                                epoch_val_loss_m12/len(val_dataloader) +
                                epoch_val_loss_m21/len(val_dataloader) +
                                epoch_val_loss_mc/len(val_dataloader),
                    'val_m1_f1_score': val_m1_f1_score,
                    'val_m2_f1_score': val_m2_f1_score,
                    'val_m12_f1_score': val_m12_f1_score,
                    'val_m21_f1_score': val_m21_f1_score,
                    'val_mc_f1_score': val_mc_f1_score
                }
            )

        # checkpoint the best model so far, using the sum of validation F1 scores
        val_f1_sum = (stats['val']['m1_f1'][-1] + stats['val']['m2_f1'][-1] +
                      stats['val']['m12_f1'][-1] + stats['val']['m21_f1'][-1] +
                      stats['val']['mc_f1'][-1])
        if val_f1_sum >= max_val_f1:
            max_val_f1 = val_f1_sum
            max_val_classification_epoch = i+1
            torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(CFG))
        f.write('{}\n'.format(train_data))
        f.write('{}\n'.format(val_data))
        f.write('{}\n'.format(stats))
        f.write('Max val classification F1: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))

    print(f'Results saved in {experiment_save_path}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)

    # Parse the command-line arguments
    args = parser.parse_args()

    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
        if not args.aggr:
            parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    # get experiment ID
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path, exist_ok=True)
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    os.makedirs(experiment_save_path, exist_ok=True)

    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed
    }

    print(CFG)
    print(f'Saving results in {experiment_save_path}')

    if args.logger:
        wandb.init(project="tbd", config=CFG)

    # set seed values
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    main(args)
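# Example invocation (a sketch: the file name train_tbd.py is an assumption,
# and --data_path simply repeats this script's default; all flags are defined
# in the argparse section above):
#
#   python train_tbd.py --gpu_id 0 --model_type tom_cm --aggr concat \
#       --batch_size 4 --lr 4e-4 --pin_memory --logger \
#       --data_path /scratch/bortoletto/data/tbd
#
# For reference, compute_f1_scores (imported from utils.helpers) is assumed to
# take alternating (pred, label) tensors for the five outputs and return five
# macro-averaged F1 scores, roughly along the lines of:
#
#   from sklearn.metrics import f1_score
#   def compute_f1_scores(*pairs):
#       return tuple(
#           f1_score(label.cpu().numpy(), pred.argmax(dim=-1).cpu().numpy(), average='macro')
#           for pred, label in zip(pairs[0::2], pairs[1::2])
#       )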