# mtomnet/boss/train.py
# 2025-01-10 15:39:20 +01:00
# 324 lines, 16 KiB, Python

import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from dataloader import Data
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.tom_tf import TFToMnet
from models.single_mindnet import SingleMindNet
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
def train(args):
    """Train the selected model on the BOSS dataset with per-epoch validation.

    Builds train/val dataloaders, instantiates the architecture named by
    ``args.model_type``, then runs the optimization loop with optional mixup,
    cosine LR scheduling, gradient clipping, wandb logging and early stopping.
    The best checkpoint (by validation accuracy) and a text log are written to
    the module-level ``experiment_save_path``; the module-level ``CFG`` dict is
    recorded in the log.

    Args:
        args: parsed argparse namespace from the ``__main__`` block.
    """
    # ToM-style multi-agent models consume features flattened over two
    # dimensions; the single-stream baselines use one.
    tom_model_types = {'tom_cm', 'tom_sl', 'tom_impl', 'tom_tf', 'tom_cm_xl', 'tom_single'}
    flatten_dim = 2 if args.model_type in tom_model_types else 1

    train_dataset = Data(
        args.train_frame_path,
        args.label_path,
        args.train_pose_path,
        args.train_gaze_path,
        args.train_bbox_path,
        args.ocr_graph_path,
        presaved=args.presaved,
        flatten_dim=flatten_dim
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory
    )
    val_dataset = Data(
        args.val_frame_path,
        args.label_path,
        args.val_pose_path,
        args.val_gaze_path,
        args.val_bbox_path,
        args.ocr_graph_path,
        presaved=args.presaved,
        flatten_dim=flatten_dim
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep validation order deterministic
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory
    )

    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # Instantiate the requested architecture. Baselines take a raw input dim
    # derived from the selected modalities; ToM variants take the hidden dim
    # plus dropout/aggregation options.
    if args.model_type == 'resnet':
        model = ResNet(get_input_dim(args.mods), device).to(device)
    elif args.model_type == 'gru':
        model = ResNetGRU(get_input_dim(args.mods), device).to(device)
    elif args.model_type == 'lstm':
        model = ResNetLSTM(get_input_dim(args.mods), device).to(device)
    elif args.model_type == 'conv1d':
        model = ResNetConv1D(get_input_dim(args.mods), device).to(device)
    elif args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
    elif args.model_type == 'tom_cm_xl':
        model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
    elif args.model_type == 'tom_single':
        model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
    elif args.model_type == 'tom_tf':
        model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
    else:
        raise NotImplementedError

    if args.resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.scheduler is None:
        scheduler = None
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)

    # tom_sl outputs log-probabilities, so it pairs with NLLLoss; all other
    # models output logits and use CrossEntropyLoss. ignore_index=-1 skips
    # padded timesteps produced by pad_collate.
    if args.model_type == 'tom_sl':
        cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
    else:
        cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)

    stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
    max_val_classification_acc = 0
    max_val_classification_epoch = None
    counter = 0  # early-stopping counter

    print(f'Number of parameters: {count_parameters(model)}')

    for i in range(args.num_epoch):
        # ---------------- training ----------------
        print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
        temp_train_classification_loss = []
        epoch_num_correct = 0
        epoch_cnt = 0
        model.train()
        for j, batch in tqdm(enumerate(train_dataloader)):
            if args.use_mixup:
                # 27 = number of belief classes per agent
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
            else:
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            # Modalities may be absent (None) depending on args.mods.
            if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
            if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
            if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
            if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
            if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)

            pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
            pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
            pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))

            if args.use_mixup:
                # Mixup labels are soft: one 27-dim distribution per agent.
                labels = torch.reshape(labels, (-1, 54)).to(device)
                batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
                    pred_left_labels, pred_right_labels, labels, sequence_lengths
                )
                loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
            else:
                # Hard labels: one class index per agent.
                labels = torch.reshape(labels, (-1, 2)).to(device)
                batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
                    pred_left_labels, pred_right_labels, labels, sequence_lengths
                )
                loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])

            epoch_cnt += batch_num_pred
            epoch_num_correct += batch_num_correct
            # loss is the sum over two agents; halve when accumulating so the
            # epoch average is per prediction.
            temp_train_classification_loss.append(loss.item() * batch_num_pred / 2)

            optimizer.zero_grad()
            loss.backward()
            # BUGFIX: clipping must happen AFTER backward. The original code
            # clipped right after zero_grad(), i.e. on zeroed gradients, so
            # clip_grad_norm had no effect.
            if args.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.item(), 'lr': optimizer.param_groups[-1]['lr']})
            print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
                i+1, args.num_epoch, j+1, len(train_dataloader), loss.item(), batch_train_acc)
            )
        if scheduler: scheduler.step()
        print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
            i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
        )
        stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
        stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)

        # ---------------- validation ----------------
        print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
        temp_val_classification_loss = []
        epoch_num_correct = 0
        epoch_cnt = 0
        model.eval()
        with torch.no_grad():
            for j, batch in tqdm(enumerate(val_dataloader)):
                frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
                if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
                if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
                if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
                if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
                if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)

                pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
                pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
                pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
                labels = torch.reshape(labels, (-1, 2)).to(device)

                batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
                    pred_left_labels, pred_right_labels, labels, sequence_lengths
                )
                epoch_cnt += batch_num_pred
                epoch_num_correct += batch_num_correct
                loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
                temp_val_classification_loss.append(loss.item() * batch_num_pred / 2)

                if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.item()})
                print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
                    i+1, args.num_epoch, j+1, len(val_dataloader), loss.item(), batch_val_acc)
                )
        print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
            i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
        )
        cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
        cls_acc = epoch_num_correct / epoch_cnt
        stats['val']['cls_loss'].append(cls_loss)
        stats['val']['cls_acc'].append(cls_acc)
        if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})

        # Checkpoint on best validation accuracy; otherwise advance the
        # early-stopping counter and stop after `patience` stale epochs.
        if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
            max_val_classification_acc = stats['val']['cls_acc'][-1]
            max_val_classification_epoch = i+1
            torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    # Persist config, per-epoch stats and the best result for this run.
    with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(CFG))
        f.write('{}\n'.format(stats))
        f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
    print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
    parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
    parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
    parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
    parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
    parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
    parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
    parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    # Parse the command-line arguments
    args = parser.parse_args()

    # Cross-argument validation. NOTE(review): these truthiness checks treat
    # 0/0.0 as "missing"; with the non-None defaults above they rarely fire.
    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
        if not args.aggr:
            parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    # get experiment ID (timestamped run directory under save_path)
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    # exist_ok=True already tolerates pre-existing dirs; no exists() check needed
    os.makedirs(args.save_path, exist_ok=True)
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    os.makedirs(experiment_save_path, exist_ok=True)

    # Run configuration, logged to wandb and written to log.txt by train().
    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed
    }
    print(CFG)
    print(f'Saving results in {experiment_save_path}')

    if args.logger:
        wandb.init(project="boss", config=CFG)

    # set seed values for all RNG sources used in this script
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    train(args)